add cert info and renewal api

This commit is contained in:
yusing 2025-02-15 21:50:34 +08:00
parent 7129e2cc9d
commit 16b046bd44
15 changed files with 201 additions and 47 deletions

View file

@ -90,7 +90,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
caCertPool := x509.NewCertPool() caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(ca) ok := caCertPool.AppendCertsFromPEM(ca)
if !ok { if !ok {
return gperr.New("invalid CA certificate") return gperr.New("invalid ca certificate")
} }
cfg.tlsConfig = &tls.Config{ cfg.tlsConfig = &tls.Config{
@ -128,21 +128,18 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
return nil return nil
} }
func (cfg *AgentConfig) Start(parent task.Parent) error { func (cfg *AgentConfig) Start(parent task.Parent) gperr.Error {
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr)) certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
if err != nil { if err != nil {
if os.IsNotExist(err) { return gperr.Wrap(err, "failed to read agent certs")
return gperr.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr)
}
return gperr.Wrap(err)
} }
ca, crt, key, err := certs.ExtractCert(certData) ca, crt, key, err := certs.ExtractCert(certData)
if err != nil { if err != nil {
return gperr.Wrap(err) return gperr.Wrap(err, "failed to extract agent certs")
} }
return cfg.StartWithCerts(parent, ca, crt, key) return gperr.Wrap(cfg.StartWithCerts(parent, ca, crt, key))
} }
func (cfg *AgentConfig) NewHTTPClient() *http.Client { func (cfg *AgentConfig) NewHTTPClient() *http.Client {

View file

@ -6,7 +6,6 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"log"
"net" "net"
"net/http" "net/http"
"time" "time"
@ -45,7 +44,6 @@ func StartAgentServer(parent task.Parent, opt Options) {
agentServer := &http.Server{ agentServer := &http.Server{
Handler: handler.NewAgentHandler(), Handler: handler.NewAgentHandler(),
TLSConfig: tlsConfig, TLSConfig: tlsConfig,
ErrorLog: log.New(logger, "", 0),
} }
go func() { go func() {

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/yusing/go-proxy
go 1.24.0 go 1.24.0
require ( require (
github.com/PuerkitoBio/goquery v1.10.1 github.com/PuerkitoBio/goquery v1.10.2
github.com/coder/websocket v1.8.12 github.com/coder/websocket v1.8.12
github.com/coreos/go-oidc/v3 v3.12.0 github.com/coreos/go-oidc/v3 v3.12.0
github.com/docker/cli v27.5.1+incompatible github.com/docker/cli v27.5.1+incompatible

4
go.sum
View file

@ -2,8 +2,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PuerkitoBio/goquery v1.10.1 h1:Y8JGYUkXWTGRB6Ars3+j3kN0xg1YqqlwvdTV8WTFQcU= github.com/PuerkitoBio/goquery v1.10.2 h1:7fh2BdHcG6VFZsK7toXBT/Bh1z5Wmy8Q9MV9HqT2AM8=
github.com/PuerkitoBio/goquery v1.10.1/go.mod h1:IYiHrOMps66ag56LEH7QYDDupKXyo5A8qrjIx3ZtujY= github.com/PuerkitoBio/goquery v1.10.2/go.mod h1:0guWGjcLu9AYC7C1GHnpysHy056u9aEkUHwhdnePMCU=
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM= github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA= github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=

View file

@ -7,6 +7,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
v1 "github.com/yusing/go-proxy/internal/api/v1" v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/certapi"
"github.com/yusing/go-proxy/internal/api/v1/favicon" "github.com/yusing/go-proxy/internal/api/v1/favicon"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
@ -54,9 +55,12 @@ func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...b
if len(requireAuth) > 0 && requireAuth[0] { if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler) handler = auth.RequireAuth(handler)
} }
if methods == "" {
for _, m := range strutils.CommaSeperatedList(methods) { mux.ServeMux.HandleFunc(endpoint, handler)
mux.ServeMux.HandleFunc(m+" "+endpoint, handler) } else {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
} }
} }
@ -82,6 +86,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true) mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true)
mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true) mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true)
mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true) mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true)
mux.HandleFunc("GET", "/v1/cert/info", certapi.GetCertInfo, true)
mux.HandleFunc("", "/v1/cert/renew", certapi.RenewCert, true)
if common.PrometheusEnabled { if common.PrometheusEnabled {
mux.Handle("GET /v1/metrics", promhttp.Handler()) mux.Handle("GET /v1/metrics", promhttp.Handler())

View file

@ -0,0 +1,41 @@
package certapi
import (
"encoding/json"
"net/http"
config "github.com/yusing/go-proxy/internal/config/types"
)
type CertInfo struct {
Subject string `json:"subject"`
Issuer string `json:"issuer"`
NotBefore int64 `json:"not_before"`
NotAfter int64 `json:"not_after"`
DNSNames []string `json:"dns_names"`
EmailAddresses []string `json:"email_addresses"`
}
func GetCertInfo(w http.ResponseWriter, r *http.Request) {
autocert := config.GetInstance().AutoCertProvider()
if autocert == nil {
http.Error(w, "autocert is not enabled", http.StatusNotFound)
return
}
cert, err := autocert.GetCert(nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
certInfo := CertInfo{
Subject: cert.Leaf.Subject.CommonName,
Issuer: cert.Leaf.Issuer.CommonName,
NotBefore: cert.Leaf.NotBefore.Unix(),
NotAfter: cert.Leaf.NotAfter.Unix(),
DNSNames: cert.Leaf.DNSNames,
EmailAddresses: cert.Leaf.EmailAddresses,
}
json.NewEncoder(w).Encode(&certInfo)
}

View file

@ -0,0 +1,56 @@
package certapi
import (
"net/http"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
)
func RenewCert(w http.ResponseWriter, r *http.Request) {
autocert := config.GetInstance().AutoCertProvider()
if autocert == nil {
http.Error(w, "autocert is not enabled", http.StatusNotFound)
return
}
conn, err := gpwebsocket.Initiate(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
//nolint:errcheck
defer conn.CloseNow()
logs, cancel := memlogger.Events()
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
err = autocert.ObtainCert()
if err != nil {
gperr.LogError("failed to obtain cert", err)
gpwebsocket.WriteText(r, conn, err.Error())
} else {
logging.Info().Msg("cert obtained successfully")
}
}()
for {
select {
case l := <-logs:
if err != nil {
return
}
if !gpwebsocket.WriteText(r, conn, string(l)) {
return
}
case <-done:
return
}
}
}

View file

@ -51,7 +51,7 @@ func (cfg *AutocertConfig) Validate() gperr.Error {
} }
b := gperr.NewBuilder("autocert errors") b := gperr.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal { if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 { if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain) b.Add(ErrMissingDomain)
} }
@ -101,7 +101,7 @@ func (cfg *AutocertConfig) GetProvider() (*Provider, gperr.Error) {
var privKey *ecdsa.PrivateKey var privKey *ecdsa.PrivateKey
var err error var err error
if cfg.Provider != ProviderLocal { if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if privKey, err = cfg.loadACMEKey(); err != nil { if privKey, err = cfg.loadACMEKey(); err != nil {
logging.Info().Err(err).Msg("load ACME private key failed") logging.Info().Err(err).Msg("load ACME private key failed")
logging.Info().Msg("generate new ACME private key") logging.Info().Msg("generate new ACME private key")

View file

@ -20,6 +20,7 @@ const (
ProviderClouddns = "clouddns" ProviderClouddns = "clouddns"
ProviderDuckdns = "duckdns" ProviderDuckdns = "duckdns"
ProviderOVH = "ovh" ProviderOVH = "ovh"
ProviderPseudo = "pseudo" // for testing
) )
var providersGenMap = map[string]ProviderGenerator{ var providersGenMap = map[string]ProviderGenerator{
@ -28,4 +29,5 @@ var providersGenMap = map[string]ProviderGenerator{
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig), ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig), ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig), ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
ProviderPseudo: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
} }

View file

@ -9,6 +9,7 @@ import (
"path" "path"
"reflect" "reflect"
"sort" "sort"
"sync"
"time" "time"
"github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/certificate"
@ -32,6 +33,8 @@ type (
legoCert *certificate.Resource legoCert *certificate.Resource
tlsCert *tls.Certificate tlsCert *tls.Certificate
certExpiries CertExpiries certExpiries CertExpiries
obtainMu sync.Mutex
} }
ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error) ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error)
@ -68,6 +71,17 @@ func (p *Provider) ObtainCert() error {
return nil return nil
} }
if p.cfg.Provider == ProviderPseudo {
t := time.NewTicker(1000 * time.Millisecond)
defer t.Stop()
logging.Info().Msg("init client for pseudo provider")
<-t.C
logging.Info().Msg("registering acme for pseudo provider")
<-t.C
logging.Info().Msg("obtained cert for pseudo provider")
return nil
}
if p.client == nil { if p.client == nil {
if err := p.initClient(); err != nil { if err := p.initClient(); err != nil {
return err return err
@ -150,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time {
} }
func (p *Provider) ScheduleRenewal(parent task.Parent) { func (p *Provider) ScheduleRenewal(parent task.Parent) {
if p.GetName() == ProviderLocal { if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
return return
} }
go func() { go func() {

View file

@ -51,10 +51,6 @@ You may run "ls-config" to show or dump the current config.`
var Validate = config.Validate var Validate = config.Validate
func GetInstance() *Config {
return config.GetInstance().(*Config)
}
func newConfig() *Config { func newConfig() *Config {
return &Config{ return &Config{
value: config.DefaultConfig(), value: config.DefaultConfig(),
@ -75,7 +71,7 @@ func Load() (*Config, gperr.Error) {
} }
func MatchDomains() []string { func MatchDomains() []string {
return GetInstance().Value().MatchDomains return config.GetInstance().Value().MatchDomains
} }
func WatchChanges() { func WatchChanges() {
@ -123,7 +119,7 @@ func Reload() gperr.Error {
// cancel all current subtasks -> wait // cancel all current subtasks -> wait
// -> replace config -> start new subtasks // -> replace config -> start new subtasks
GetInstance().Task().Finish("config changed") config.GetInstance().(*Config).Task().Finish("config changed")
newCfg.Start(StartAllServers) newCfg.Start(StartAllServers)
config.SetInstance(newCfg) config.SetInstance(newCfg)
return nil return nil

View file

@ -43,6 +43,7 @@ type (
GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool) GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool)
AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error) AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error)
ListAgents() []*agent.AgentConfig ListAgents() []*agent.AgentConfig
AutoCertProvider() *autocert.Provider
} }
) )

View file

@ -0,0 +1,42 @@
package types
import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestValidateConfig(t *testing.T) {
cases := []struct {
name string
data []byte
want gperr.Error
}{
{
name: "valid config",
data: []byte(`
autocert:
provider: local
`),
want: nil,
},
{
name: "unknown field",
data: []byte(`
autocert:
provider: local
unknown: true
`),
want: utils.ErrUnknownField,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := Validate(c.data)
ExpectError(t, c.want, got)
})
}
}

View file

@ -90,34 +90,35 @@ func (m *memLogger) truncateIfNeeded(n int) {
} }
func (m *memLogger) notifyWS(pos, n int) { func (m *memLogger) notifyWS(pos, n int) {
if m.connChans.Size() > 0 { if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
timeout := time.NewTimer(2 * time.Second) return
defer timeout.Stop() }
m.notifyLock.RLock() timeout := time.NewTimer(3 * time.Second)
defer m.notifyLock.RUnlock() defer timeout.Stop()
m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
m.notifyLock.RLock()
defer m.notifyLock.RUnlock()
m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
select {
case ch <- &logEntryRange{pos, pos + n}:
return true
case <-timeout.C:
return false
}
})
if m.listeners.Size() > 0 {
msg := m.Buffer.Bytes()[pos : pos+n]
m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
select { select {
case ch <- &logEntryRange{pos, pos + n}:
return true
case <-timeout.C: case <-timeout.C:
logging.Warn().Msg("mem logger: timeout logging to channel")
return false return false
case ch <- msg:
return true
} }
}) })
if m.listeners.Size() > 0 {
msg := m.Buffer.Bytes()[pos : pos+n]
m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
select {
case <-timeout.C:
logging.Warn().Msg("mem logger: timeout logging to channel")
return false
case ch <- msg:
return true
}
})
}
return
} }
} }