simplify setup process with WebUI

This commit is contained in:
yusing 2025-02-14 20:14:16 +08:00
parent 7047d37f70
commit 9f54f40f5a
21 changed files with 590 additions and 451 deletions

View file

@ -1,15 +0,0 @@
services:
godoxy-agent:
image: ghcr.io/yusing/godoxy-agent:latest
container_name: godoxy-agent
restart: always
network_mode: host # do not change this
environment:
AGENT_NAME: # defaults to hostname
AGENT_PORT: # defaults to 8890
# comma separated list of allowed main server IPs or CIDRs
# to register from this agent
REGISTRATION_ALLOWED_HOSTS:
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ./certs:/app/certs # store Agent CA cert and Agent SSL cert

View file

@ -1,45 +1,45 @@
package main package main
import ( import (
"fmt" "os"
"github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/agent/pkg/env"
"github.com/yusing/go-proxy/agent/pkg/server" "github.com/yusing/go-proxy/agent/pkg/server"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/pkg" "github.com/yusing/go-proxy/pkg"
"gopkg.in/yaml.v3"
) )
func printNewClientHelp() {
ip, ok := agent.MachineIP()
if !ok {
logging.Warn().Msg("No valid network interface found, change <machine-ip> to your actual IP")
ip = "<machine-ip>"
} else {
logging.Info().Msgf("Detected machine IP: %s, change if needed", ip)
}
host := fmt.Sprintf("%s:%d", ip, env.AgentPort)
cfgYAML, _ := yaml.Marshal(map[string]any{
"providers": map[string]any{
"agents": host,
},
})
logging.Info().Msgf("On main server, run:\n\ndocker exec godoxy /app/run add-agent '%s'\n", host)
logging.Info().Msgf("Then add this host (%s) to main server config like below:\n", host)
logging.Info().Msg(string(cfgYAML))
}
func main() { func main() {
ca, srv, isNew, err := certs.InitCerts() args := os.Args
if len(args) > 1 && args[1] == "migrate" {
if err := agent.MigrateFromOld(); err != nil {
E.LogFatal("failed to migrate from old docker compose", err)
}
return
}
_ = os.Chmod("/app/compose.yml", 0600)
ca := &agent.PEMPair{}
err := ca.Load(env.AgentCACert)
if err != nil { if err != nil {
E.LogFatal("init CA error", err) E.LogFatal("init CA error", err)
} }
caCert, err := ca.ToTLSCert()
if err != nil {
E.LogFatal("init CA error", err)
}
srv := &agent.PEMPair{}
srv.Load(env.AgentSSLCert)
if err != nil {
E.LogFatal("init SSL error", err)
}
srvCert, err := srv.ToTLSCert()
if err != nil {
E.LogFatal("init SSL error", err)
}
logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion()) logging.Info().Msgf("GoDoxy Agent version %s", pkg.GetVersion())
logging.Info().Msgf("Agent name: %s", env.AgentName) logging.Info().Msgf("Agent name: %s", env.AgentName)
@ -49,23 +49,15 @@ func main() {
Tips: Tips:
1. To change the agent name, you can set the AGENT_NAME environment variable. 1. To change the agent name, you can set the AGENT_NAME environment variable.
2. To change the agent port, you can set the AGENT_PORT environment variable. 2. To change the agent port, you can set the AGENT_PORT environment variable.
3. To skip the version check, you can set AGENT_SKIP_VERSION_CHECK to true.
4. If anything goes wrong, you can remove the 'certs' directory and start over.
`) `)
t := task.RootTask("agent", false) t := task.RootTask("agent", false)
opts := server.Options{ opts := server.Options{
CACert: ca, CACert: caCert,
ServerCert: srv, ServerCert: srvCert,
Port: env.AgentPort, Port: env.AgentPort,
} }
if isNew {
logging.Info().Msg("Initialization complete.")
printNewClientHelp()
server.StartRegistrationServer(t, opts)
}
server.StartAgentServer(t, opts) server.StartAgentServer(t, opts)
task.WaitExit(3) task.WaitExit(3)

View file

@ -0,0 +1,14 @@
services:
agent:
image: "{{.Image}}"
container_name: godoxy-agent
restart: always
network_mode: host # do not change this
environment:
AGENT_NAME: "{{.Name}}"
AGENT_PORT: "{{.Port}}"
AGENT_CA_CERT: "{{.CACert}}"
AGENT_SSL_CERT: "{{.SSLCert}}"
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- ./compose.yml:/app/compose.yml

View file

@ -38,7 +38,7 @@ const (
EndpointLogs = "/logs" EndpointLogs = "/logs"
EndpointSystemInfo = "/system_info" EndpointSystemInfo = "/system_info"
AgentHost = certs.CertsDNSName AgentHost = CertsDNSName
APIEndpointBase = "/godoxy/agent" APIEndpointBase = "/godoxy/agent"
APIBaseURL = "https://" + AgentHost + APIEndpointBase APIBaseURL = "https://" + AgentHost + APIEndpointBase
@ -80,20 +80,7 @@ func checkVersion(a, b string) bool {
return withoutBuildTime(a) == withoutBuildTime(b) return withoutBuildTime(a) == withoutBuildTime(b)
} }
func (cfg *AgentConfig) Start(parent task.Parent) E.Error { func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) E.Error {
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
if err != nil {
if os.IsNotExist(err) {
return E.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr)
}
return E.Wrap(err)
}
ca, crt, key, err := certs.ExtractCert(certData)
if err != nil {
return E.Wrap(err)
}
clientCert, err := tls.X509KeyPair(crt, key) clientCert, err := tls.X509KeyPair(crt, key)
if err != nil { if err != nil {
return E.Wrap(err) return E.Wrap(err)
@ -109,6 +96,7 @@ func (cfg *AgentConfig) Start(parent task.Parent) E.Error {
cfg.tlsConfig = &tls.Config{ cfg.tlsConfig = &tls.Config{
Certificates: []tls.Certificate{clientCert}, Certificates: []tls.Certificate{clientCert},
RootCAs: caCertPool, RootCAs: caCertPool,
ServerName: CertsDNSName,
} }
// create transport and http client // create transport and http client
@ -140,6 +128,23 @@ func (cfg *AgentConfig) Start(parent task.Parent) E.Error {
return nil return nil
} }
func (cfg *AgentConfig) Start(parent task.Parent) E.Error {
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
if err != nil {
if os.IsNotExist(err) {
return E.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr)
}
return E.Wrap(err)
}
ca, crt, key, err := certs.ExtractCert(certData)
if err != nil {
return E.Wrap(err)
}
return cfg.StartWithCerts(parent, ca, crt, key)
}
func (cfg *AgentConfig) NewHTTPClient() *http.Client { func (cfg *AgentConfig) NewHTTPClient() *http.Client {
return &http.Client{ return &http.Client{
Transport: cfg.Transport(), Transport: cfg.Transport(),

View file

@ -0,0 +1,123 @@
package agent
import (
"bytes"
"os"
"path"
"strconv"
"text/template"
_ "embed"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
"gopkg.in/yaml.v3"
)
//go:embed agent.compose.yml
var agentComposeYAML []byte
var agentComposeYAMLTemplate = template.Must(template.New("agent.compose.yml").Parse(string(agentComposeYAML)))
const (
DockerImageProduction = "ghcr.io/yusing/godoxy-agent:latest"
DockerImageNightly = "yusing/godoxy-agent-nightly:latest"
)
type (
AgentComposeConfig struct {
Image string
Name string
Port int
CACert string
SSLCert string
}
)
func (c *AgentComposeConfig) Generate() (string, error) {
buf := bytes.NewBuffer(make([]byte, 0, 1024))
if err := agentComposeYAMLTemplate.Execute(buf, c); err != nil {
return "", err
}
return buf.String(), nil
}
func pemPairFromFile(path string) (*PEMPair, error) {
cert, err := os.ReadFile(path + ".crt")
if err != nil {
return nil, err
}
key, err := os.ReadFile(path + ".key")
if err != nil {
return nil, err
}
return &PEMPair{
Cert: cert,
Key: key,
}, nil
}
func rmOldCerts(p string) error {
files, err := utils.ListFiles(p, 0)
if err != nil {
return err
}
for _, file := range files {
if err := os.Remove(path.Join(p, file)); err != nil {
return err
}
}
return nil
}
type dockerCompose struct {
Services struct {
GodoxyAgent struct {
Environment struct {
AGENT_NAME string `yaml:"GODOXY_AGENT_NAME"`
AGENT_PORT string `yaml:"GODOXY_AGENT_PORT"`
} `yaml:"environment"`
} `yaml:"godoxy-agent"`
} `yaml:"services"`
}
// TODO: remove this
func MigrateFromOld() error {
oldCompose, err := os.ReadFile("/app/compose.yml")
if err != nil {
return err
}
var compose dockerCompose
if err := yaml.Unmarshal(oldCompose, &compose); err != nil {
return err
}
ca, err := pemPairFromFile("/app/certs/ca")
if err != nil {
return err
}
agentCert, err := pemPairFromFile("/app/certs/agent")
if err != nil {
return err
}
var composeConfig AgentComposeConfig
composeConfig.Image = DockerImageNightly
composeConfig.Name = compose.Services.GodoxyAgent.Environment.AGENT_NAME
composeConfig.Port, err = strconv.Atoi(compose.Services.GodoxyAgent.Environment.AGENT_PORT) // ignore error, empty is fine
if composeConfig.Port == 0 {
composeConfig.Port = 8890
}
composeConfig.CACert = ca.String()
composeConfig.SSLCert = agentCert.String()
composeTemplate, err := composeConfig.Generate()
if err != nil {
return E.Wrap(err, "failed to generate new docker compose")
}
if err := os.WriteFile("/app/compose.yml", []byte(composeTemplate), 0600); err != nil {
return E.Wrap(err, "failed to write new docker compose")
}
logging.Info().Msg("Migrated from old docker compose:")
logging.Info().Msg(composeTemplate)
return nil
}

View file

@ -0,0 +1,139 @@
package agent
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"errors"
"math/big"
"strings"
"time"
)
const (
CertsDNSName = "godoxy.agent"
KeySize = 2048
)
func toPEMPair(certDER []byte, key *rsa.PrivateKey) *PEMPair {
return &PEMPair{
Cert: pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}),
Key: pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}),
}
}
func b64Encode(data []byte) string {
return base64.StdEncoding.EncodeToString(data)
}
func b64Decode(data string) ([]byte, error) {
return base64.StdEncoding.DecodeString(data)
}
type PEMPair struct {
Cert, Key []byte
}
func (p *PEMPair) String() string {
return b64Encode(p.Cert) + ";" + b64Encode(p.Key)
}
func (p *PEMPair) Load(data string) (err error) {
parts := strings.Split(data, ";")
if len(parts) != 2 {
return errors.New("invalid PEM pair")
}
p.Cert, err = b64Decode(parts[0])
if err != nil {
return err
}
p.Key, err = b64Decode(parts[1])
if err != nil {
return err
}
return nil
}
func (p *PEMPair) ToTLSCert() (*tls.Certificate, error) {
cert, err := tls.X509KeyPair(p.Cert, p.Key)
return &cert, err
}
func NewAgent() (ca, srv, client *PEMPair, err error) {
// Create the CA's certificate
caTemplate := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"GoDoxy"},
CommonName: CertsDNSName,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caKey, err := rsa.GenerateKey(rand.Reader, KeySize)
if err != nil {
return nil, nil, nil, err
}
caDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey)
if err != nil {
return nil, nil, nil, err
}
ca = toPEMPair(caDER, caKey)
// Generate a new private key for the server certificate
serverKey, err := rsa.GenerateKey(rand.Reader, KeySize)
if err != nil {
return nil, nil, nil, err
}
srvTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Issuer: caTemplate.Subject,
Subject: caTemplate.Subject,
DNSNames: []string{CertsDNSName},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1000, 0, 0), // Add validity period
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
srvCertDER, err := x509.CreateCertificate(rand.Reader, srvTemplate, caTemplate, &serverKey.PublicKey, caKey)
if err != nil {
return nil, nil, nil, err
}
srv = toPEMPair(srvCertDER, serverKey)
clientKey, err := rsa.GenerateKey(rand.Reader, KeySize)
if err != nil {
return nil, nil, nil, err
}
clientTemplate := &x509.Certificate{
SerialNumber: big.NewInt(3),
Issuer: caTemplate.Subject,
Subject: caTemplate.Subject,
DNSNames: []string{CertsDNSName},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1000, 0, 0),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
clientCertDER, err := x509.CreateCertificate(rand.Reader, clientTemplate, caTemplate, &clientKey.PublicKey, caKey)
if err != nil {
return nil, nil, nil, err
}
client = toPEMPair(clientCertDER, clientKey)
return
}

View file

@ -0,0 +1,91 @@
package agent
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"net/http/httptest"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestNewAgent(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
ExpectTrue(t, ca != nil)
ExpectTrue(t, srv != nil)
ExpectTrue(t, client != nil)
}
func TestPEMPair(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("load-%d", i), func(t *testing.T) {
var pp PEMPair
err := pp.Load(p.String())
ExpectNoError(t, err)
ExpectBytesEqual(t, p.Cert, pp.Cert)
ExpectBytesEqual(t, p.Key, pp.Key)
})
}
}
func TestPEMPairToTLSCert(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
for i, p := range []*PEMPair{ca, srv, client} {
t.Run(fmt.Sprintf("toTLSCert-%d", i), func(t *testing.T) {
cert, err := p.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, cert != nil)
})
}
}
func TestServerClient(t *testing.T) {
ca, srv, client, err := NewAgent()
ExpectNoError(t, err)
srvTLS, err := srv.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, srvTLS != nil)
clientTLS, err := client.ToTLSCert()
ExpectNoError(t, err)
ExpectTrue(t, clientTLS != nil)
caPool := x509.NewCertPool()
ExpectTrue(t, caPool.AppendCertsFromPEM(ca.Cert))
srvTLSConfig := &tls.Config{
Certificates: []tls.Certificate{*srvTLS},
ClientCAs: caPool,
ClientAuth: tls.RequireAndVerifyClientCert,
}
clientTLSConfig := &tls.Config{
Certificates: []tls.Certificate{*clientTLS},
RootCAs: caPool,
ServerName: CertsDNSName,
}
server := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
server.TLS = srvTLSConfig
server.StartTLS()
defer server.Close()
httpClient := &http.Client{
Transport: &http.Transport{TLSClientConfig: clientTLSConfig},
}
resp, err := httpClient.Get(server.URL)
ExpectNoError(t, err)
ExpectEqual(t, resp.StatusCode, http.StatusOK)
}

View file

@ -1,30 +0,0 @@
package agent
import (
"net"
"strings"
)
func MachineIP() (string, bool) {
interfaces, err := net.Interfaces()
if err != nil {
interfaces = []net.Interface{}
}
for _, in := range interfaces {
addrs, err := in.Addrs()
if err != nil {
continue
}
if !strings.HasPrefix(in.Name, "eth") && !strings.HasPrefix(in.Name, "en") {
continue
}
for _, addr := range addrs {
if ipnet, ok := addr.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
return ipnet.IP.String(), true
}
}
}
}
return "", false
}

View file

@ -1,201 +0,0 @@
package certs
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"errors"
"math/big"
"os"
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils"
)
const (
CertsDNSName = "godoxy.agent"
caCertPath = "certs/ca.crt"
caKeyPath = "certs/ca.key"
srvCertPath = "certs/agent.crt"
srvKeyPath = "certs/agent.key"
)
func loadCerts(certPath, keyPath string) (*tls.Certificate, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
return &cert, err
}
func write(b []byte, f *os.File) error {
_, err := f.Write(b)
return err
}
func saveCerts(certDER []byte, key *rsa.PrivateKey, certPath, keyPath string) ([]byte, []byte, error) {
certPEM, keyPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}),
pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
if certPath == "" || keyPath == "" {
return certPEM, keyPEM, nil
}
certFile, err := os.Create(certPath)
if err != nil {
return nil, nil, err
}
defer certFile.Close()
keyFile, err := os.Create(keyPath)
if err != nil {
return nil, nil, err
}
defer keyFile.Close()
return certPEM, keyPEM, errors.Join(
write(certPEM, certFile),
write(keyPEM, keyFile),
)
}
func checkExists(certPath, keyPath string) bool {
certExists, err := utils.FileExists(certPath)
if err != nil {
E.LogFatal("cert error", err)
}
keyExists, err := utils.FileExists(keyPath)
if err != nil {
E.LogFatal("key error", err)
}
return certExists && keyExists
}
func InitCerts() (ca *tls.Certificate, srv *tls.Certificate, isNew bool, err error) {
if checkExists(caCertPath, caKeyPath) && checkExists(srvCertPath, srvKeyPath) {
logging.Info().Msg("Loading existing certs...")
ca, err = loadCerts(caCertPath, caKeyPath)
if err != nil {
return nil, nil, false, err
}
srv, err = loadCerts(srvCertPath, srvKeyPath)
if err != nil {
return nil, nil, false, err
}
logging.Info().Msg("Verifying agent cert...")
roots := x509.NewCertPool()
roots.AddCert(ca.Leaf)
srvCert, err := x509.ParseCertificate(srv.Certificate[0])
if err != nil {
return nil, nil, false, err
}
// check if srv is signed by ca
if _, err := srvCert.Verify(x509.VerifyOptions{
Roots: roots,
}); err == nil {
logging.Info().Msg("OK")
return ca, srv, false, nil
}
logging.Error().Msg("Agent cert and CA cert mismatch, regenerating")
}
// Create the CA's certificate
caTemplate := x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
Organization: []string{"GoDoxy"},
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
caKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, false, err
}
caCertDER, err := x509.CreateCertificate(rand.Reader, &caTemplate, &caTemplate, &caKey.PublicKey, caKey)
if err != nil {
return nil, nil, false, err
}
certPEM, keyPEM, err := saveCerts(caCertDER, caKey, caCertPath, caKeyPath)
if err != nil {
return nil, nil, false, err
}
caCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, false, err
}
ca = &caCert
// Generate a new private key for the server certificate
serverKey, err := rsa.GenerateKey(rand.Reader, 4096)
if err != nil {
return nil, nil, false, err
}
srvTemplate := caTemplate
srvTemplate.Issuer = srvTemplate.Subject
srvTemplate.DNSNames = append(srvTemplate.DNSNames, CertsDNSName)
srvCertDER, err := x509.CreateCertificate(rand.Reader, &srvTemplate, &caTemplate, &serverKey.PublicKey, caKey)
if err != nil {
return nil, nil, false, err
}
certPEM, keyPEM, err = saveCerts(srvCertDER, serverKey, srvCertPath, srvKeyPath)
if err != nil {
return nil, nil, false, err
}
agentCert, err := tls.X509KeyPair(certPEM, keyPEM)
if err != nil {
return nil, nil, false, err
}
srv = &agentCert
return ca, srv, true, nil
}
func NewClientCert(ca *tls.Certificate) ([]byte, []byte, error) {
// Generate the SSL's private key
sslKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, nil, err
}
// Create the SSL's certificate
sslTemplate := &x509.Certificate{
SerialNumber: big.NewInt(2),
Subject: pkix.Name{
Organization: []string{"GoDoxy"},
CommonName: CertsDNSName,
},
NotBefore: time.Now(),
NotAfter: time.Now().AddDate(1000, 0, 0), // 1000 years
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
}
// Sign the certificate with the CA
sslCertDER, err := x509.CreateCertificate(rand.Reader, sslTemplate, ca.Leaf, &sslKey.PublicKey, ca.PrivateKey)
if err != nil {
return nil, nil, err
}
return saveCerts(sslCertDER, sslKey, "", "")
}

View file

@ -12,7 +12,7 @@ import (
func writeFile(zipWriter *zip.Writer, name string, data []byte) error { func writeFile(zipWriter *zip.Writer, name string, data []byte) error {
w, err := zipWriter.CreateHeader(&zip.FileHeader{ w, err := zipWriter.CreateHeader(&zip.FileHeader{
Name: name, Name: name,
Method: zip.Deflate, Method: zip.Store,
}) })
if err != nil { if err != nil {
return err return err
@ -31,8 +31,7 @@ func readFile(f *zip.File) ([]byte, error) {
} }
func ZipCert(ca, crt, key []byte) ([]byte, error) { func ZipCert(ca, crt, key []byte) ([]byte, error) {
data := bytes.NewBuffer(nil) data := bytes.NewBuffer(make([]byte, 0, 6144))
data.Grow(6144)
zipWriter := zip.NewWriter(data) zipWriter := zip.NewWriter(data)
defer zipWriter.Close() defer zipWriter.Close()

56
agent/pkg/env/env.go vendored
View file

@ -1,11 +1,7 @@
package env package env
import ( import (
"log"
"net"
"os" "os"
"strings"
"sync"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
) )
@ -24,54 +20,6 @@ var (
AgentRegistrationPort = common.GetEnvInt("AGENT_REGISTRATION_PORT", 8891) AgentRegistrationPort = common.GetEnvInt("AGENT_REGISTRATION_PORT", 8891)
AgentSkipClientCertCheck = common.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false) AgentSkipClientCertCheck = common.GetEnvBool("AGENT_SKIP_CLIENT_CERT_CHECK", false)
RegistrationAllowedHosts = common.GetCommaSepEnv("REGISTRATION_ALLOWED_HOSTS", "") AgentCACert = common.GetEnvString("AGENT_CA_CERT", "")
RegistrationAllowedCIDRs []*net.IPNet AgentSSLCert = common.GetEnvString("AGENT_SSL_CERT", "")
) )
func init() {
cidrs, err := toCIDRs(RegistrationAllowedHosts)
if err != nil {
log.Fatalf("failed to parse allowed hosts: %v", err)
}
RegistrationAllowedCIDRs = cidrs
}
func toCIDRs(hosts []string) ([]*net.IPNet, error) {
cidrs := make([]*net.IPNet, 0, len(hosts))
for _, host := range hosts {
if !strings.Contains(host, "/") {
host += "/32"
}
_, cidr, err := net.ParseCIDR(host)
if err != nil {
return nil, err
}
cidrs = append(cidrs, cidr)
}
return cidrs, nil
}
var warnOnce sync.Once
func IsAllowedHost(remoteAddr string) bool {
if len(RegistrationAllowedCIDRs) == 0 {
warnOnce.Do(func() {
log.Println("Warning: REGISTRATION_ALLOWED_HOSTS is empty, allowing all hosts")
})
return true
}
ip, _, err := net.SplitHostPort(remoteAddr)
if err != nil {
ip = remoteAddr
}
netIP := net.ParseIP(ip)
if netIP == nil {
return false
}
for _, cidr := range RegistrationAllowedCIDRs {
if cidr.Contains(netIP) {
return true
}
}
return false
}

View file

@ -1,22 +1,15 @@
package handler package handler
import ( import (
"crypto/tls"
"encoding/pem"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/certs"
"github.com/yusing/go-proxy/agent/pkg/env" "github.com/yusing/go-proxy/agent/pkg/env"
v1 "github.com/yusing/go-proxy/internal/api/v1" v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/utils"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/systeminfo" "github.com/yusing/go-proxy/internal/metrics/systeminfo"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -54,46 +47,3 @@ func NewAgentHandler() http.Handler {
mux.ServeMux.HandleFunc("/", DockerSocketHandler()) mux.ServeMux.HandleFunc("/", DockerSocketHandler())
return mux return mux
} }
// NewRegistrationHandler creates a new registration handler
// It checks if the request is coming from an allowed host
// Generates a new client certificate and zips it
// Sends the zipped certificate to the client
// its run only once on agent first start.
func NewRegistrationHandler(task *task.Task, ca *tls.Certificate) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
if !env.IsAllowedHost(r.RemoteAddr) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
return
}
if r.URL.Path == "/done" {
logging.Info().Msg("registration done")
task.Finish(nil)
w.WriteHeader(http.StatusOK)
return
}
logging.Info().Msgf("received registration request from %s", r.RemoteAddr)
caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: ca.Certificate[0]})
crt, key, err := certs.NewClientCert(ca)
if err != nil {
utils.HandleErr(w, r, E.Wrap(err, "failed to generate client certificate"))
return
}
zipped, err := certs.ZipCert(caPEM, crt, key)
if err != nil {
utils.HandleErr(w, r, E.Wrap(err, "failed to zip certificate"))
return
}
w.Header().Set("Content-Type", "application/zip")
if _, err := w.Write(zipped); err != nil {
logging.Error().Err(err).Msg("failed to respond to registration request")
return
}
}
}

View file

@ -77,32 +77,3 @@ func StartAgentServer(parent task.Parent, opt Options) {
} }
}() }()
} }
func StartRegistrationServer(parent task.Parent, opt Options) {
t := parent.Subtask("registration_server")
logger := logging.GetLogger()
registrationServer := &http.Server{
Addr: fmt.Sprintf(":%d", opt.Port),
Handler: handler.NewRegistrationHandler(t, opt.CACert),
ErrorLog: log.New(logger, "", 0),
}
go func() {
err := registrationServer.ListenAndServe()
server.HandleError(logger, err, "failed to serve registration server")
}()
logging.Info().Int("port", opt.Port).Msg("registration server started")
defer t.Finish(nil)
<-t.Context().Done()
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
err := registrationServer.Shutdown(ctx)
server.HandleError(logger, err, "failed to shutdown registration server")
logging.Info().Int("port", opt.Port).Msg("registration server stopped")
}

View file

@ -77,7 +77,9 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true) mux.HandleFunc("GET", "/v1/logs", memlogger.Handler(), true)
mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true) mux.HandleFunc("GET", "/v1/favicon", favicon.GetFavIcon, true)
mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true) mux.HandleFunc("POST", "/v1/homepage/set", v1.SetHomePageOverrides, true)
mux.HandleFunc("GET", "/v1/agents", v1.AgentsWS, true) mux.HandleFunc("GET", "/v1/agents", v1.ListAgents, true)
mux.HandleFunc("GET", "/v1/agents/new", v1.NewAgent, true)
mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true)
mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true) mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true)
mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true) mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true)

View file

@ -11,7 +11,7 @@ import (
"github.com/yusing/go-proxy/internal/net/http/httpheaders" "github.com/yusing/go-proxy/internal/net/http/httpheaders"
) )
func AgentsWS(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func ListAgents(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
if httpheaders.IsWebsocket(r.Header) { if httpheaders.IsWebsocket(r.Header) {
U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error { U.PeriodicWS(w, r, 10*time.Second, func(conn *websocket.Conn) error {
wsjson.Write(r.Context(), conn, cfg.ListAgents()) wsjson.Write(r.Context(), conn, cfg.ListAgents())

View file

@ -28,7 +28,6 @@ const (
ListHomepageCategories = "homepage_categories" ListHomepageCategories = "homepage_categories"
ListIcons = "icons" ListIcons = "icons"
ListTasks = "tasks" ListTasks = "tasks"
ListAgents = "agents"
) )
func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) { func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
@ -78,8 +77,6 @@ func List(cfg config.ConfigInstance, w http.ResponseWriter, r *http.Request) {
U.RespondJSON(w, r, icons) U.RespondJSON(w, r, icons)
case ListTasks: case ListTasks:
U.RespondJSON(w, r, task.DebugTaskList()) U.RespondJSON(w, r, task.DebugTaskList())
case ListAgents:
U.RespondJSON(w, r, cfg.ListAgents())
default: default:
U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest) U.HandleErr(w, r, U.ErrInvalidKey("what"), http.StatusBadRequest)
} }

View file

@ -0,0 +1,135 @@
package v1
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"strconv"
_ "embed"
"github.com/yusing/go-proxy/agent/pkg/agent"
"github.com/yusing/go-proxy/agent/pkg/certs"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
func NewAgent(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
name := q.Get("name")
if name == "" {
U.RespondError(w, U.ErrMissingKey("name"))
return
}
host := q.Get("host")
if host == "" {
U.RespondError(w, U.ErrMissingKey("host"))
return
}
portStr := q.Get("port")
if portStr == "" {
U.RespondError(w, U.ErrMissingKey("port"))
return
}
port, err := strconv.Atoi(portStr)
if err != nil || port < 1 || port > 65535 {
U.RespondError(w, U.ErrInvalidKey("port"))
return
}
hostport := fmt.Sprintf("%s:%d", host, port)
if _, ok := config.GetInstance().GetAgent(hostport); ok {
U.RespondError(w, U.ErrAlreadyExists("agent", hostport), http.StatusConflict)
return
}
t := q.Get("type")
switch t {
case "docker":
break
case "system":
U.RespondError(w, U.Errorf("system agent is not supported yet"), http.StatusNotImplemented)
return
case "":
U.RespondError(w, U.ErrMissingKey("type"))
return
default:
U.RespondError(w, U.ErrInvalidKey("type"))
return
}
nightly := strutils.ParseBool(q.Get("nightly"))
var image string
if nightly {
image = agent.DockerImageNightly
} else {
image = agent.DockerImageProduction
}
ca, srv, client, err := agent.NewAgent()
if err != nil {
U.HandleErr(w, r, err)
return
}
cfg := agent.AgentComposeConfig{
Image: image,
Name: name,
Port: port,
CACert: ca.String(),
SSLCert: srv.String(),
}
template, err := cfg.Generate()
if err != nil {
U.HandleErr(w, r, err)
return
}
U.RespondJSON(w, r, map[string]any{
"compose": template,
"ca": ca,
"client": client,
})
}
func AddAgent(w http.ResponseWriter, r *http.Request) {
defer r.Body.Close()
clientPEMData, err := io.ReadAll(r.Body)
if err != nil {
U.HandleErr(w, r, err)
return
}
var data struct {
Host string `json:"host"`
CA agent.PEMPair `json:"ca"`
Client agent.PEMPair `json:"client"`
}
if err := json.Unmarshal(clientPEMData, &data); err != nil {
U.RespondError(w, err, http.StatusBadRequest)
return
}
nRoutesAdded, err := config.GetInstance().AddAgent(data.Host, data.CA, data.Client)
if err != nil {
U.RespondError(w, err)
return
}
zip, err := certs.ZipCert(data.CA.Cert, data.Client.Cert, data.Client.Key)
if err != nil {
U.HandleErr(w, r, err)
return
}
if err := os.WriteFile(certs.AgentCertsFilename(data.Host), zip, 0600); err != nil {
U.HandleErr(w, r, err)
return
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(fmt.Sprintf("Added %d routes", nRoutesAdded)))
}

View file

@ -2,7 +2,6 @@ package utils
import ( import (
"context" "context"
"encoding/json"
"errors" "errors"
"net/http" "net/http"
"syscall" "syscall"
@ -42,25 +41,26 @@ func RespondError(w http.ResponseWriter, err error, code ...int) {
if len(code) == 0 { if len(code) == 0 {
code = []int{http.StatusBadRequest} code = []int{http.StatusBadRequest}
} }
buf, err := json.Marshal(err) w.Header().Set("Content-Type", "text/plain; charset=utf-8")
if err != nil { // just in case http.Error(w, ansi.StripANSI(err.Error()), code[0])
w.Header().Set("Content-Type", "text/plain; charset=utf-8") }
http.Error(w, ansi.StripANSI(err.Error()), code[0])
return func Errorf(format string, args ...any) error {
} return E.Errorf(format, args...)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(code[0])
_, _ = w.Write(buf)
} }
func ErrMissingKey(k string) error { func ErrMissingKey(k string) error {
return E.New("missing key '" + k + "' in query or request body") return E.New(k + " is required")
} }
func ErrInvalidKey(k string) error { func ErrInvalidKey(k string) error {
return E.New("invalid key '" + k + "' in query or request body") return E.New(k + " is invalid")
}
func ErrAlreadyExists(k, v string) error {
return E.Errorf("%s %q already exists", k, v)
} }
func ErrNotFound(k, v string) error { func ErrNotFound(k, v string) error {
return E.Errorf("key %q with value %q not found", k, v) return E.Errorf("%s %q not found", k, v)
} }

View file

@ -2,6 +2,8 @@ package config
import ( import (
"github.com/yusing/go-proxy/agent/pkg/agent" "github.com/yusing/go-proxy/agent/pkg/agent"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/route/provider"
"github.com/yusing/go-proxy/internal/utils/functional" "github.com/yusing/go-proxy/internal/utils/functional"
) )
@ -27,6 +29,22 @@ func (cfg *Config) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, b
return GetAgent(agent.GetAgentAddrFromDockerHost(agentAddrOrDockerHost)) return GetAgent(agent.GetAgentAddrFromDockerHost(agentAddrOrDockerHost))
} }
func (cfg *Config) AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, E.Error) {
var agentCfg agent.AgentConfig
agentCfg.Addr = host
err := agentCfg.StartWithCerts(cfg.Task(), ca.Cert, client.Cert, client.Key)
if err != nil {
return 0, err
}
provider := provider.NewAgentProvider(&agentCfg)
if err := cfg.errIfExists(provider); err != nil {
return 0, err
}
cfg.storeProvider(provider)
return provider.NumRoutes(), nil
}
func (cfg *Config) ListAgents() []*agent.AgentConfig { func (cfg *Config) ListAgents() []*agent.AgentConfig {
agents := make([]*agent.AgentConfig, 0, agentPool.Size()) agents := make([]*agent.AgentConfig, 0, agentPool.Size())
agentPool.RangeAll(func(key string, value *agent.AgentConfig) { agentPool.RangeAll(func(key string, value *agent.AgentConfig) {

View file

@ -42,6 +42,7 @@ type (
RouteProviderList() []string RouteProviderList() []string
Context() context.Context Context() context.Context
GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool)
AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, E.Error)
ListAgents() []*agent.AgentConfig ListAgents() []*agent.AgentConfig
} }
) )

View file

@ -130,7 +130,7 @@ func (s *Server) stop() {
return return
} }
ctx, cancel := context.WithTimeout(task.RootContext(), 3*time.Second) ctx, cancel := context.WithTimeout(task.RootContext(), 5*time.Second)
defer cancel() defer cancel()
if s.http != nil && s.httpStarted { if s.http != nil && s.httpStarted {