From 16b046bd444fd737b0a18f143a43b2cd7f3ddae7 Mon Sep 17 00:00:00 2001 From: yusing Date: Sat, 15 Feb 2025 21:50:34 +0800 Subject: [PATCH] add cert info and renewal api --- agent/pkg/agent/config.go | 13 ++--- agent/pkg/server/server.go | 2 - go.mod | 2 +- go.sum | 4 +- internal/api/handler.go | 12 +++-- internal/api/v1/certapi/cert_info.go | 41 +++++++++++++++ internal/api/v1/certapi/renew.go | 56 +++++++++++++++++++++ internal/api/v1/{file.go => config_file.go} | 0 internal/autocert/config.go | 4 +- internal/autocert/constants.go | 2 + internal/autocert/provider.go | 16 +++++- internal/config/config.go | 8 +-- internal/config/types/config.go | 1 + internal/config/types/config_test.go | 42 ++++++++++++++++ internal/logging/memlogger/mem_logger.go | 45 +++++++++-------- 15 files changed, 201 insertions(+), 47 deletions(-) create mode 100644 internal/api/v1/certapi/cert_info.go create mode 100644 internal/api/v1/certapi/renew.go rename internal/api/v1/{file.go => config_file.go} (100%) create mode 100644 internal/config/types/config_test.go diff --git a/agent/pkg/agent/config.go b/agent/pkg/agent/config.go index 6849342..c7aa704 100644 --- a/agent/pkg/agent/config.go +++ b/agent/pkg/agent/config.go @@ -90,7 +90,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) caCertPool := x509.NewCertPool() ok := caCertPool.AppendCertsFromPEM(ca) if !ok { - return gperr.New("invalid CA certificate") + return gperr.New("invalid ca certificate") } cfg.tlsConfig = &tls.Config{ @@ -128,21 +128,18 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte) return nil } -func (cfg *AgentConfig) Start(parent task.Parent) error { +func (cfg *AgentConfig) Start(parent task.Parent) gperr.Error { certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) if err != nil { - if os.IsNotExist(err) { - return gperr.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr) - } - return gperr.Wrap(err) + return gperr.Wrap(err, "failed to read agent certs") } ca, crt, key, err := certs.ExtractCert(certData) if err != nil { - return gperr.Wrap(err) + return gperr.Wrap(err, "failed to extract agent certs") } - return cfg.StartWithCerts(parent, ca, crt, key) + return gperr.Wrap(cfg.StartWithCerts(parent, ca, crt, key)) } func (cfg *AgentConfig) NewHTTPClient() *http.Client { diff --git a/agent/pkg/server/server.go b/agent/pkg/server/server.go index 736fd81..fb0910a 100644 --- a/agent/pkg/server/server.go +++ b/agent/pkg/server/server.go @@ -6,7 +6,6 @@ import ( "crypto/x509" "encoding/pem" "fmt" - "log" "net" "net/http" "time" @@ -45,7 +44,6 @@ func StartAgentServer(parent task.Parent, opt Options) { agentServer := &http.Server{ Handler: handler.NewAgentHandler(), TLSConfig: tlsConfig, - ErrorLog: log.New(logger, "", 0), } go func() { diff --git a/go.mod b/go.mod index 2ffee93..5d2dd3a 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,7 @@ module github.com/yusing/go-proxy go 1.24.0 require ( - github.com/PuerkitoBio/goquery v1.10.1 + github.com/PuerkitoBio/goquery v1.10.2 github.com/coder/websocket v1.8.12 github.com/coreos/go-oidc/v3 v3.12.0 github.com/docker/cli v27.5.1+incompatible diff --git a/go.sum b/go.sum index 7453ef2..9d8f3ad 100644 --- a/go.sum +++ b/go.sum @@ -2,8 +2,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= -github.com/PuerkitoBio/goquery v1.10.1 h1:Y8JGYUkXWTGRB6Ars3+j3kN0xg1YqqlwvdTV8WTFQcU= -github.com/PuerkitoBio/goquery v1.10.1/go.mod h1:IYiHrOMps66ag56LEH7QYDDupKXyo5A8qrjIx3ZtujY= +github.com/PuerkitoBio/goquery v1.10.2 h1:7fh2BdHcG6VFZsK7toXBT/Bh1z5Wmy8Q9MV9HqT2AM8= +github.com/PuerkitoBio/goquery v1.10.2/go.mod h1:0guWGjcLu9AYC7C1GHnpysHy056u9aEkUHwhdnePMCU= github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= diff --git a/internal/api/handler.go b/internal/api/handler.go index 9de8651..7e05682 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -7,6 +7,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" v1 "github.com/yusing/go-proxy/internal/api/v1" "github.com/yusing/go-proxy/internal/api/v1/auth" + "github.com/yusing/go-proxy/internal/api/v1/certapi" "github.com/yusing/go-proxy/internal/api/v1/favicon" "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" @@ -54,9 +55,12 @@ func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...b if len(requireAuth) > 0 && requireAuth[0] { handler = auth.RequireAuth(handler) } - - for _, m := range strutils.CommaSeperatedList(methods) { - mux.ServeMux.HandleFunc(m+" "+endpoint, handler) + if methods == "" { + mux.ServeMux.HandleFunc(endpoint, handler) + } else { + for _, m := range strutils.CommaSeperatedList(methods) { + mux.ServeMux.HandleFunc(m+" "+endpoint, handler) + } } } @@ -82,6 +86,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true) mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true) mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true) + mux.HandleFunc("GET", "/v1/cert/info", certapi.GetCertInfo, true) + mux.HandleFunc("", "/v1/cert/renew", certapi.RenewCert, true) if common.PrometheusEnabled { mux.Handle("GET /v1/metrics", promhttp.Handler()) diff --git a/internal/api/v1/certapi/cert_info.go b/internal/api/v1/certapi/cert_info.go new file mode 100644 index 0000000..07edfd9 --- /dev/null +++ b/internal/api/v1/certapi/cert_info.go @@ -0,0 +1,41 @@ +package certapi + +import ( + "encoding/json" + "net/http" + + config "github.com/yusing/go-proxy/internal/config/types" +) + +type CertInfo struct { + Subject string `json:"subject"` + Issuer string `json:"issuer"` + NotBefore int64 `json:"not_before"` + NotAfter int64 `json:"not_after"` + DNSNames []string `json:"dns_names"` + EmailAddresses []string `json:"email_addresses"` +} + +func GetCertInfo(w http.ResponseWriter, r *http.Request) { + autocert := config.GetInstance().AutoCertProvider() + if autocert == nil { + http.Error(w, "autocert is not enabled", http.StatusNotFound) + return + } + + cert, err := autocert.GetCert(nil) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + certInfo := CertInfo{ + Subject: cert.Leaf.Subject.CommonName, + Issuer: cert.Leaf.Issuer.CommonName, + NotBefore: cert.Leaf.NotBefore.Unix(), + NotAfter: cert.Leaf.NotAfter.Unix(), + DNSNames: cert.Leaf.DNSNames, + EmailAddresses: cert.Leaf.EmailAddresses, + } + json.NewEncoder(w).Encode(&certInfo) +} diff --git a/internal/api/v1/certapi/renew.go b/internal/api/v1/certapi/renew.go new file mode 100644 index 0000000..bb993f9 --- /dev/null +++ b/internal/api/v1/certapi/renew.go @@ -0,0 +1,56 @@ +package certapi + +import ( + "net/http" + + config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/logging/memlogger" + "github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket" +) + +func RenewCert(w http.ResponseWriter, r *http.Request) { + autocert := config.GetInstance().AutoCertProvider() + if autocert == nil { + http.Error(w, "autocert is not enabled", http.StatusNotFound) + return + } + + conn, err := gpwebsocket.Initiate(w, r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + //nolint:errcheck + defer conn.CloseNow() + + logs, cancel := memlogger.Events() + defer cancel() + + done := make(chan struct{}) + + go func() { + defer close(done) + err = autocert.ObtainCert() + if err != nil { + gperr.LogError("failed to obtain cert", err) + gpwebsocket.WriteText(r, conn, err.Error()) + } else { + logging.Info().Msg("cert obtained successfully") + } + }() + for { + select { + case l := <-logs: + if err != nil { + return + } + if !gpwebsocket.WriteText(r, conn, string(l)) { + return + } + case <-done: + return + } + } +} diff --git a/internal/api/v1/file.go b/internal/api/v1/config_file.go similarity index 100% rename from internal/api/v1/file.go rename to internal/api/v1/config_file.go diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 19de2ba..09a98be 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -51,7 +51,7 @@ func (cfg *AutocertConfig) Validate() gperr.Error { } b := gperr.NewBuilder("autocert errors") - if cfg.Provider != ProviderLocal { + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if len(cfg.Domains) == 0 { b.Add(ErrMissingDomain) } @@ -101,7 +101,7 @@ func (cfg *AutocertConfig) GetProvider() (*Provider, gperr.Error) { var privKey *ecdsa.PrivateKey var err error - if cfg.Provider != ProviderLocal { + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if privKey, err = cfg.loadACMEKey(); err != nil { logging.Info().Err(err).Msg("load ACME private key failed") logging.Info().Msg("generate new ACME private key") diff --git a/internal/autocert/constants.go b/internal/autocert/constants.go index f50a109..adbbc9c 100644 --- a/internal/autocert/constants.go +++ b/internal/autocert/constants.go @@ -20,6 +20,7 @@ const ( ProviderClouddns = "clouddns" ProviderDuckdns = "duckdns" ProviderOVH = "ovh" + ProviderPseudo = "pseudo" // for testing ) var providersGenMap = map[string]ProviderGenerator{ @@ -28,4 +29,5 @@ var providersGenMap = map[string]ProviderGenerator{ ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig), ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig), ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig), + ProviderPseudo: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig), } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 9948d32..654d80a 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -9,6 +9,7 @@ import ( "path" "reflect" "sort" + "sync" "time" "github.com/go-acme/lego/v4/certificate" @@ -32,6 +33,8 @@ type ( legoCert *certificate.Resource tlsCert *tls.Certificate certExpiries CertExpiries + + obtainMu sync.Mutex } ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error) @@ -68,6 +71,17 @@ func (p *Provider) ObtainCert() error { return nil } + if p.cfg.Provider == ProviderPseudo { + t := time.NewTicker(1000 * time.Millisecond) + defer t.Stop() + logging.Info().Msg("init client for pseudo provider") + <-t.C + logging.Info().Msg("registering acme for pseudo provider") + <-t.C + logging.Info().Msg("obtained cert for pseudo provider") + return nil + } + if p.client == nil { if err := p.initClient(); err != nil { return err @@ -150,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time { } func (p *Provider) ScheduleRenewal(parent task.Parent) { - if p.GetName() == ProviderLocal { + if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo { return } go func() { diff --git a/internal/config/config.go b/internal/config/config.go index 28fd5e6..5cc3bef 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -51,10 +51,6 @@ You may run "ls-config" to show or dump the current config.` var Validate = config.Validate -func GetInstance() *Config { - return config.GetInstance().(*Config) -} - func newConfig() *Config { return &Config{ value: config.DefaultConfig(), @@ -75,7 +71,7 @@ func Load() (*Config, gperr.Error) { } func MatchDomains() []string { - return GetInstance().Value().MatchDomains + return config.GetInstance().Value().MatchDomains } func WatchChanges() { @@ -123,7 +119,7 @@ func Reload() gperr.Error { // cancel all current subtasks -> wait // -> replace config -> start new subtasks - GetInstance().Task().Finish("config changed") + config.GetInstance().(*Config).Task().Finish("config changed") newCfg.Start(StartAllServers) config.SetInstance(newCfg) return nil diff --git a/internal/config/types/config.go b/internal/config/types/config.go index 60c766c..76d5ae9 100644 --- a/internal/config/types/config.go +++ b/internal/config/types/config.go @@ -43,6 +43,7 @@ type ( GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) ListAgents() []*agent.AgentConfig + AutoCertProvider() *autocert.Provider } ) diff --git a/internal/config/types/config_test.go b/internal/config/types/config_test.go new file mode 100644 index 0000000..9f34456 --- /dev/null +++ b/internal/config/types/config_test.go @@ -0,0 +1,42 @@ +package types + +import ( + "testing" + + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/utils" + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +func TestValidateConfig(t *testing.T) { + cases := []struct { + name string + data []byte + want gperr.Error + }{ + { + name: "valid config", + data: []byte(` +autocert: + provider: local +`), + want: nil, + }, + { + name: "unknown field", + data: []byte(` +autocert: + provider: local + unknown: true +`), + want: utils.ErrUnknownField, + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + got := Validate(c.data) + ExpectError(t, c.want, got) + }) + } +} diff --git a/internal/logging/memlogger/mem_logger.go b/internal/logging/memlogger/mem_logger.go index 6c3cf69..a75fad5 100644 --- a/internal/logging/memlogger/mem_logger.go +++ b/internal/logging/memlogger/mem_logger.go @@ -90,34 +90,35 @@ func (m *memLogger) truncateIfNeeded(n int) { } func (m *memLogger) notifyWS(pos, n int) { - if m.connChans.Size() > 0 { - timeout := time.NewTimer(2 * time.Second) - defer timeout.Stop() + if m.connChans.Size() == 0 && m.listeners.Size() == 0 { + return + } - m.notifyLock.RLock() - defer m.notifyLock.RUnlock() - m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool { + timeout := time.NewTimer(3 * time.Second) + defer timeout.Stop() + + m.notifyLock.RLock() + defer m.notifyLock.RUnlock() + + m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool { + select { + case ch <- &logEntryRange{pos, pos + n}: + return true + case <-timeout.C: + return false + } + }) + + if m.listeners.Size() > 0 { + msg := m.Buffer.Bytes()[pos : pos+n] + m.listeners.Range(func(ch chan []byte, _ struct{}) bool { select { - case ch <- &logEntryRange{pos, pos + n}: - return true case <-timeout.C: - logging.Warn().Msg("mem logger: timeout logging to channel") return false + case ch <- msg: + return true } }) - if m.listeners.Size() > 0 { - msg := m.Buffer.Bytes()[pos : pos+n] - m.listeners.Range(func(ch chan []byte, _ struct{}) bool { - select { - case <-timeout.C: - logging.Warn().Msg("mem logger: timeout logging to channel") - return false - case ch <- msg: - return true - } - }) - } - return } }