From e4f6994dfc2b8122cdf1affe3e524fbb852c4906 Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 28 Mar 2025 06:04:09 +0800 Subject: [PATCH] autocert: refactor and add pseudo provider for testing --- internal/autocert/config.go | 26 ++++++++--------- internal/autocert/constants.go | 2 ++ internal/autocert/provider.go | 53 ++++++++++++++++++++++------------ internal/autocert/setup.go | 6 ++-- 4 files changed, 52 insertions(+), 35 deletions(-) diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 3a32f73..09a98be 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -10,7 +10,7 @@ import ( "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/lego" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" @@ -30,17 +30,17 @@ type ( ) var ( - ErrMissingDomain = E.New("missing field 'domains'") - ErrMissingEmail = E.New("missing field 'email'") - ErrMissingProvider = E.New("missing field 'provider'") - ErrInvalidDomain = E.New("invalid domain") - ErrUnknownProvider = E.New("unknown provider") + ErrMissingDomain = gperr.New("missing field 'domains'") + ErrMissingEmail = gperr.New("missing field 'email'") + ErrMissingProvider = gperr.New("missing field 'provider'") + ErrInvalidDomain = gperr.New("invalid domain") + ErrUnknownProvider = gperr.New("unknown provider") ) var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`) // Validate implements the utils.CustomValidator interface. -func (cfg *AutocertConfig) Validate() E.Error { +func (cfg *AutocertConfig) Validate() gperr.Error { if cfg == nil { return nil } @@ -50,8 +50,8 @@ func (cfg *AutocertConfig) Validate() E.Error { return nil } - b := E.NewBuilder("autocert errors") - if cfg.Provider != ProviderLocal { + b := gperr.NewBuilder("autocert errors") + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if len(cfg.Domains) == 0 { b.Add(ErrMissingDomain) } @@ -79,7 +79,7 @@ func (cfg *AutocertConfig) Validate() E.Error { return b.Error() } -func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) { +func (cfg *AutocertConfig) GetProvider() (*Provider, gperr.Error) { if cfg == nil { cfg = new(AutocertConfig) } @@ -101,16 +101,16 @@ func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) { var privKey *ecdsa.PrivateKey var err error - if cfg.Provider != ProviderLocal { + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if privKey, err = cfg.loadACMEKey(); err != nil { logging.Info().Err(err).Msg("load ACME private key failed") logging.Info().Msg("generate new ACME private key") privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - return nil, E.New("generate ACME private key").With(err) + return nil, gperr.New("generate ACME private key").With(err) } if err = cfg.saveACMEKey(privKey); err != nil { - return nil, E.New("save ACME private key").With(err) + return nil, gperr.New("save ACME private key").With(err) } } } diff --git a/internal/autocert/constants.go b/internal/autocert/constants.go index d1d98a4..0ff5fb6 100644 --- a/internal/autocert/constants.go +++ b/internal/autocert/constants.go @@ -21,6 +21,7 @@ const ( ProviderClouddns = "clouddns" ProviderDuckdns = "duckdns" ProviderOVH = "ovh" + ProviderPseudo = "pseudo" // for testing ProviderPorkbun = "porkbun" ) @@ -30,5 +31,6 @@ var providersGenMap = map[string]ProviderGenerator{ ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig), ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig), ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig), + ProviderPseudo: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig), ProviderPorkbun: providerGenerator(porkbun.NewDefaultConfig, porkbun.NewDNSProviderConfig), } diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index eaad175..654d80a 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -4,17 +4,19 @@ import ( "crypto/tls" "crypto/x509" "errors" + "fmt" "os" "path" "reflect" "sort" + "sync" "time" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" - E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/task" U "github.com/yusing/go-proxy/internal/utils" @@ -31,8 +33,10 @@ type ( legoCert *certificate.Resource tlsCert *tls.Certificate certExpiries CertExpiries + + obtainMu sync.Mutex } - ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error) + ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error) CertExpiries map[string]time.Time ) @@ -62,11 +66,22 @@ func (p *Provider) GetExpiries() CertExpiries { return p.certExpiries } -func (p *Provider) ObtainCert() E.Error { +func (p *Provider) ObtainCert() error { if p.cfg.Provider == ProviderLocal { return nil } + if p.cfg.Provider == ProviderPseudo { + t := time.NewTicker(1000 * time.Millisecond) + defer t.Stop() + logging.Info().Msg("init client for pseudo provider") + <-t.C + logging.Info().Msg("registering acme for pseudo provider") + <-t.C + logging.Info().Msg("obtained cert for pseudo provider") + return nil + } + if p.client == nil { if err := p.initClient(); err != nil { return err @@ -75,7 +90,7 @@ func (p *Provider) ObtainCert() E.Error { if p.user.Registration == nil { if err := p.registerACME(); err != nil { - return E.From(err) + return err } } @@ -100,22 +115,22 @@ func (p *Provider) ObtainCert() E.Error { Bundle: true, }) if err != nil { - return E.From(err) + return err } } if err = p.saveCert(cert); err != nil { - return E.From(err) + return err } tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey) if err != nil { - return E.From(err) + return err } expiries, err := getCertExpiries(&tlsCert) if err != nil { - return E.From(err) + return err } p.tlsCert = &tlsCert p.certExpiries = expiries @@ -123,14 +138,14 @@ func (p *Provider) ObtainCert() E.Error { return nil } -func (p *Provider) LoadCert() E.Error { +func (p *Provider) LoadCert() error { cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath) if err != nil { - return E.Errorf("load SSL certificate: %w", err) + return fmt.Errorf("load SSL certificate: %w", err) } expiries, err := getCertExpiries(&cert) if err != nil { - return E.Errorf("parse SSL certificate: %w", err) + return fmt.Errorf("parse SSL certificate: %w", err) } p.tlsCert = &cert p.certExpiries = expiries @@ -149,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time { } func (p *Provider) ScheduleRenewal(parent task.Parent) { - if p.GetName() == ProviderLocal { + if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo { return } go func() { @@ -171,7 +186,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { continue } if err := p.renewIfNeeded(); err != nil { - E.LogWarn("cert renew failed", err) + gperr.LogWarn("cert renew failed", err) lastErrOn = time.Now() continue } @@ -184,10 +199,10 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) { }() } -func (p *Provider) initClient() E.Error { +func (p *Provider) initClient() error { legoClient, err := lego.NewClient(p.legoCfg) if err != nil { - return E.From(err) + return err } generator := providersGenMap[p.cfg.Provider] @@ -198,7 +213,7 @@ func (p *Provider) initClient() E.Error { err = legoClient.Challenge.SetDNS01Provider(legoProvider) if err != nil { - return E.From(err) + return err } p.client = legoClient @@ -273,7 +288,7 @@ func (p *Provider) certState() CertState { return CertStateValid } -func (p *Provider) renewIfNeeded() E.Error { +func (p *Provider) renewIfNeeded() error { if p.cfg.Provider == ProviderLocal { return nil } @@ -312,13 +327,13 @@ func providerGenerator[CT any, PT challenge.Provider]( defaultCfg func() *CT, newProvider func(*CT) (PT, error), ) ProviderGenerator { - return func(opt ProviderOpt) (challenge.Provider, E.Error) { + return func(opt ProviderOpt) (challenge.Provider, gperr.Error) { cfg := defaultCfg() err := U.Deserialize(opt, &cfg) if err != nil { return nil, err } p, pErr := newProvider(cfg) - return p, E.From(pErr) + return p, gperr.Wrap(pErr) } } diff --git a/internal/autocert/setup.go b/internal/autocert/setup.go index 82e58ad..b436bed 100644 --- a/internal/autocert/setup.go +++ b/internal/autocert/setup.go @@ -1,16 +1,16 @@ package autocert import ( + "errors" "os" - E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils/strutils" ) -func (p *Provider) Setup() (err E.Error) { +func (p *Provider) Setup() (err error) { if err = p.LoadCert(); err != nil { - if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist + if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist return err } logging.Debug().Msg("obtaining cert due to error loading cert")