autocert: refactor and add pseudo provider for testing

This commit is contained in:
yusing 2025-03-28 06:04:09 +08:00
parent 827a27911c
commit e4f6994dfc
4 changed files with 52 additions and 35 deletions

View file

@ -10,7 +10,7 @@ import (
"github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certcrypto"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
@ -30,17 +30,17 @@ type (
) )
var ( var (
ErrMissingDomain = E.New("missing field 'domains'") ErrMissingDomain = gperr.New("missing field 'domains'")
ErrMissingEmail = E.New("missing field 'email'") ErrMissingEmail = gperr.New("missing field 'email'")
ErrMissingProvider = E.New("missing field 'provider'") ErrMissingProvider = gperr.New("missing field 'provider'")
ErrInvalidDomain = E.New("invalid domain") ErrInvalidDomain = gperr.New("invalid domain")
ErrUnknownProvider = E.New("unknown provider") ErrUnknownProvider = gperr.New("unknown provider")
) )
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`) var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
// Validate implements the utils.CustomValidator interface. // Validate implements the utils.CustomValidator interface.
func (cfg *AutocertConfig) Validate() E.Error { func (cfg *AutocertConfig) Validate() gperr.Error {
if cfg == nil { if cfg == nil {
return nil return nil
} }
@ -50,8 +50,8 @@ func (cfg *AutocertConfig) Validate() E.Error {
return nil return nil
} }
b := E.NewBuilder("autocert errors") b := gperr.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal { if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 { if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain) b.Add(ErrMissingDomain)
} }
@ -79,7 +79,7 @@ func (cfg *AutocertConfig) Validate() E.Error {
return b.Error() return b.Error()
} }
func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) { func (cfg *AutocertConfig) GetProvider() (*Provider, gperr.Error) {
if cfg == nil { if cfg == nil {
cfg = new(AutocertConfig) cfg = new(AutocertConfig)
} }
@ -101,16 +101,16 @@ func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
var privKey *ecdsa.PrivateKey var privKey *ecdsa.PrivateKey
var err error var err error
if cfg.Provider != ProviderLocal { if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if privKey, err = cfg.loadACMEKey(); err != nil { if privKey, err = cfg.loadACMEKey(); err != nil {
logging.Info().Err(err).Msg("load ACME private key failed") logging.Info().Err(err).Msg("load ACME private key failed")
logging.Info().Msg("generate new ACME private key") logging.Info().Msg("generate new ACME private key")
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil { if err != nil {
return nil, E.New("generate ACME private key").With(err) return nil, gperr.New("generate ACME private key").With(err)
} }
if err = cfg.saveACMEKey(privKey); err != nil { if err = cfg.saveACMEKey(privKey); err != nil {
return nil, E.New("save ACME private key").With(err) return nil, gperr.New("save ACME private key").With(err)
} }
} }
} }

View file

@ -21,6 +21,7 @@ const (
ProviderClouddns = "clouddns" ProviderClouddns = "clouddns"
ProviderDuckdns = "duckdns" ProviderDuckdns = "duckdns"
ProviderOVH = "ovh" ProviderOVH = "ovh"
ProviderPseudo = "pseudo" // for testing
ProviderPorkbun = "porkbun" ProviderPorkbun = "porkbun"
) )
@ -30,5 +31,6 @@ var providersGenMap = map[string]ProviderGenerator{
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig), ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig), ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig), ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
ProviderPseudo: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
ProviderPorkbun: providerGenerator(porkbun.NewDefaultConfig, porkbun.NewDNSProviderConfig), ProviderPorkbun: providerGenerator(porkbun.NewDefaultConfig, porkbun.NewDNSProviderConfig),
} }

View file

@ -4,17 +4,19 @@ import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors" "errors"
"fmt"
"os" "os"
"path" "path"
"reflect" "reflect"
"sort" "sort"
"sync"
"time" "time"
"github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/certificate"
"github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge"
"github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/lego"
"github.com/go-acme/lego/v4/registration" "github.com/go-acme/lego/v4/registration"
E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task" "github.com/yusing/go-proxy/internal/task"
U "github.com/yusing/go-proxy/internal/utils" U "github.com/yusing/go-proxy/internal/utils"
@ -31,8 +33,10 @@ type (
legoCert *certificate.Resource legoCert *certificate.Resource
tlsCert *tls.Certificate tlsCert *tls.Certificate
certExpiries CertExpiries certExpiries CertExpiries
obtainMu sync.Mutex
} }
ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error) ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error)
CertExpiries map[string]time.Time CertExpiries map[string]time.Time
) )
@ -62,11 +66,22 @@ func (p *Provider) GetExpiries() CertExpiries {
return p.certExpiries return p.certExpiries
} }
func (p *Provider) ObtainCert() E.Error { func (p *Provider) ObtainCert() error {
if p.cfg.Provider == ProviderLocal { if p.cfg.Provider == ProviderLocal {
return nil return nil
} }
if p.cfg.Provider == ProviderPseudo {
t := time.NewTicker(1000 * time.Millisecond)
defer t.Stop()
logging.Info().Msg("init client for pseudo provider")
<-t.C
logging.Info().Msg("registering acme for pseudo provider")
<-t.C
logging.Info().Msg("obtained cert for pseudo provider")
return nil
}
if p.client == nil { if p.client == nil {
if err := p.initClient(); err != nil { if err := p.initClient(); err != nil {
return err return err
@ -75,7 +90,7 @@ func (p *Provider) ObtainCert() E.Error {
if p.user.Registration == nil { if p.user.Registration == nil {
if err := p.registerACME(); err != nil { if err := p.registerACME(); err != nil {
return E.From(err) return err
} }
} }
@ -100,22 +115,22 @@ func (p *Provider) ObtainCert() E.Error {
Bundle: true, Bundle: true,
}) })
if err != nil { if err != nil {
return E.From(err) return err
} }
} }
if err = p.saveCert(cert); err != nil { if err = p.saveCert(cert); err != nil {
return E.From(err) return err
} }
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey) tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
if err != nil { if err != nil {
return E.From(err) return err
} }
expiries, err := getCertExpiries(&tlsCert) expiries, err := getCertExpiries(&tlsCert)
if err != nil { if err != nil {
return E.From(err) return err
} }
p.tlsCert = &tlsCert p.tlsCert = &tlsCert
p.certExpiries = expiries p.certExpiries = expiries
@ -123,14 +138,14 @@ func (p *Provider) ObtainCert() E.Error {
return nil return nil
} }
func (p *Provider) LoadCert() E.Error { func (p *Provider) LoadCert() error {
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath) cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
if err != nil { if err != nil {
return E.Errorf("load SSL certificate: %w", err) return fmt.Errorf("load SSL certificate: %w", err)
} }
expiries, err := getCertExpiries(&cert) expiries, err := getCertExpiries(&cert)
if err != nil { if err != nil {
return E.Errorf("parse SSL certificate: %w", err) return fmt.Errorf("parse SSL certificate: %w", err)
} }
p.tlsCert = &cert p.tlsCert = &cert
p.certExpiries = expiries p.certExpiries = expiries
@ -149,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time {
} }
func (p *Provider) ScheduleRenewal(parent task.Parent) { func (p *Provider) ScheduleRenewal(parent task.Parent) {
if p.GetName() == ProviderLocal { if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
return return
} }
go func() { go func() {
@ -171,7 +186,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
continue continue
} }
if err := p.renewIfNeeded(); err != nil { if err := p.renewIfNeeded(); err != nil {
E.LogWarn("cert renew failed", err) gperr.LogWarn("cert renew failed", err)
lastErrOn = time.Now() lastErrOn = time.Now()
continue continue
} }
@ -184,10 +199,10 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
}() }()
} }
func (p *Provider) initClient() E.Error { func (p *Provider) initClient() error {
legoClient, err := lego.NewClient(p.legoCfg) legoClient, err := lego.NewClient(p.legoCfg)
if err != nil { if err != nil {
return E.From(err) return err
} }
generator := providersGenMap[p.cfg.Provider] generator := providersGenMap[p.cfg.Provider]
@ -198,7 +213,7 @@ func (p *Provider) initClient() E.Error {
err = legoClient.Challenge.SetDNS01Provider(legoProvider) err = legoClient.Challenge.SetDNS01Provider(legoProvider)
if err != nil { if err != nil {
return E.From(err) return err
} }
p.client = legoClient p.client = legoClient
@ -273,7 +288,7 @@ func (p *Provider) certState() CertState {
return CertStateValid return CertStateValid
} }
func (p *Provider) renewIfNeeded() E.Error { func (p *Provider) renewIfNeeded() error {
if p.cfg.Provider == ProviderLocal { if p.cfg.Provider == ProviderLocal {
return nil return nil
} }
@ -312,13 +327,13 @@ func providerGenerator[CT any, PT challenge.Provider](
defaultCfg func() *CT, defaultCfg func() *CT,
newProvider func(*CT) (PT, error), newProvider func(*CT) (PT, error),
) ProviderGenerator { ) ProviderGenerator {
return func(opt ProviderOpt) (challenge.Provider, E.Error) { return func(opt ProviderOpt) (challenge.Provider, gperr.Error) {
cfg := defaultCfg() cfg := defaultCfg()
err := U.Deserialize(opt, &cfg) err := U.Deserialize(opt, &cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
p, pErr := newProvider(cfg) p, pErr := newProvider(cfg)
return p, E.From(pErr) return p, gperr.Wrap(pErr)
} }
} }

View file

@ -1,16 +1,16 @@
package autocert package autocert
import ( import (
"errors"
"os" "os"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
func (p *Provider) Setup() (err E.Error) { func (p *Provider) Setup() (err error) {
if err = p.LoadCert(); err != nil { if err = p.LoadCert(); err != nil {
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
return err return err
} }
logging.Debug().Msg("obtaining cert due to error loading cert") logging.Debug().Msg("obtaining cert due to error loading cert")