mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-09 13:02:33 +02:00
autocert: refactor and add pseudo provider for testing
This commit is contained in:
parent
827a27911c
commit
e4f6994dfc
4 changed files with 52 additions and 35 deletions
|
@ -10,7 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/go-acme/lego/v4/certcrypto"
|
"github.com/go-acme/lego/v4/certcrypto"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
@ -30,17 +30,17 @@ type (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingDomain = E.New("missing field 'domains'")
|
ErrMissingDomain = gperr.New("missing field 'domains'")
|
||||||
ErrMissingEmail = E.New("missing field 'email'")
|
ErrMissingEmail = gperr.New("missing field 'email'")
|
||||||
ErrMissingProvider = E.New("missing field 'provider'")
|
ErrMissingProvider = gperr.New("missing field 'provider'")
|
||||||
ErrInvalidDomain = E.New("invalid domain")
|
ErrInvalidDomain = gperr.New("invalid domain")
|
||||||
ErrUnknownProvider = E.New("unknown provider")
|
ErrUnknownProvider = gperr.New("unknown provider")
|
||||||
)
|
)
|
||||||
|
|
||||||
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`)
|
||||||
|
|
||||||
// Validate implements the utils.CustomValidator interface.
|
// Validate implements the utils.CustomValidator interface.
|
||||||
func (cfg *AutocertConfig) Validate() E.Error {
|
func (cfg *AutocertConfig) Validate() gperr.Error {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -50,8 +50,8 @@ func (cfg *AutocertConfig) Validate() E.Error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
b := E.NewBuilder("autocert errors")
|
b := gperr.NewBuilder("autocert errors")
|
||||||
if cfg.Provider != ProviderLocal {
|
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||||
if len(cfg.Domains) == 0 {
|
if len(cfg.Domains) == 0 {
|
||||||
b.Add(ErrMissingDomain)
|
b.Add(ErrMissingDomain)
|
||||||
}
|
}
|
||||||
|
@ -79,7 +79,7 @@ func (cfg *AutocertConfig) Validate() E.Error {
|
||||||
return b.Error()
|
return b.Error()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
|
func (cfg *AutocertConfig) GetProvider() (*Provider, gperr.Error) {
|
||||||
if cfg == nil {
|
if cfg == nil {
|
||||||
cfg = new(AutocertConfig)
|
cfg = new(AutocertConfig)
|
||||||
}
|
}
|
||||||
|
@ -101,16 +101,16 @@ func (cfg *AutocertConfig) GetProvider() (*Provider, E.Error) {
|
||||||
var privKey *ecdsa.PrivateKey
|
var privKey *ecdsa.PrivateKey
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if cfg.Provider != ProviderLocal {
|
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
|
||||||
if privKey, err = cfg.loadACMEKey(); err != nil {
|
if privKey, err = cfg.loadACMEKey(); err != nil {
|
||||||
logging.Info().Err(err).Msg("load ACME private key failed")
|
logging.Info().Err(err).Msg("load ACME private key failed")
|
||||||
logging.Info().Msg("generate new ACME private key")
|
logging.Info().Msg("generate new ACME private key")
|
||||||
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
privKey, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, E.New("generate ACME private key").With(err)
|
return nil, gperr.New("generate ACME private key").With(err)
|
||||||
}
|
}
|
||||||
if err = cfg.saveACMEKey(privKey); err != nil {
|
if err = cfg.saveACMEKey(privKey); err != nil {
|
||||||
return nil, E.New("save ACME private key").With(err)
|
return nil, gperr.New("save ACME private key").With(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,6 +21,7 @@ const (
|
||||||
ProviderClouddns = "clouddns"
|
ProviderClouddns = "clouddns"
|
||||||
ProviderDuckdns = "duckdns"
|
ProviderDuckdns = "duckdns"
|
||||||
ProviderOVH = "ovh"
|
ProviderOVH = "ovh"
|
||||||
|
ProviderPseudo = "pseudo" // for testing
|
||||||
ProviderPorkbun = "porkbun"
|
ProviderPorkbun = "porkbun"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -30,5 +31,6 @@ var providersGenMap = map[string]ProviderGenerator{
|
||||||
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
|
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
|
||||||
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
|
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
|
||||||
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
|
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
|
||||||
|
ProviderPseudo: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
|
||||||
ProviderPorkbun: providerGenerator(porkbun.NewDefaultConfig, porkbun.NewDNSProviderConfig),
|
ProviderPorkbun: providerGenerator(porkbun.NewDefaultConfig, porkbun.NewDNSProviderConfig),
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,17 +4,19 @@ import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-acme/lego/v4/certificate"
|
"github.com/go-acme/lego/v4/certificate"
|
||||||
"github.com/go-acme/lego/v4/challenge"
|
"github.com/go-acme/lego/v4/challenge"
|
||||||
"github.com/go-acme/lego/v4/lego"
|
"github.com/go-acme/lego/v4/lego"
|
||||||
"github.com/go-acme/lego/v4/registration"
|
"github.com/go-acme/lego/v4/registration"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
U "github.com/yusing/go-proxy/internal/utils"
|
U "github.com/yusing/go-proxy/internal/utils"
|
||||||
|
@ -31,8 +33,10 @@ type (
|
||||||
legoCert *certificate.Resource
|
legoCert *certificate.Resource
|
||||||
tlsCert *tls.Certificate
|
tlsCert *tls.Certificate
|
||||||
certExpiries CertExpiries
|
certExpiries CertExpiries
|
||||||
|
|
||||||
|
obtainMu sync.Mutex
|
||||||
}
|
}
|
||||||
ProviderGenerator func(ProviderOpt) (challenge.Provider, E.Error)
|
ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error)
|
||||||
|
|
||||||
CertExpiries map[string]time.Time
|
CertExpiries map[string]time.Time
|
||||||
)
|
)
|
||||||
|
@ -62,11 +66,22 @@ func (p *Provider) GetExpiries() CertExpiries {
|
||||||
return p.certExpiries
|
return p.certExpiries
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) ObtainCert() E.Error {
|
func (p *Provider) ObtainCert() error {
|
||||||
if p.cfg.Provider == ProviderLocal {
|
if p.cfg.Provider == ProviderLocal {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if p.cfg.Provider == ProviderPseudo {
|
||||||
|
t := time.NewTicker(1000 * time.Millisecond)
|
||||||
|
defer t.Stop()
|
||||||
|
logging.Info().Msg("init client for pseudo provider")
|
||||||
|
<-t.C
|
||||||
|
logging.Info().Msg("registering acme for pseudo provider")
|
||||||
|
<-t.C
|
||||||
|
logging.Info().Msg("obtained cert for pseudo provider")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
if p.client == nil {
|
if p.client == nil {
|
||||||
if err := p.initClient(); err != nil {
|
if err := p.initClient(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -75,7 +90,7 @@ func (p *Provider) ObtainCert() E.Error {
|
||||||
|
|
||||||
if p.user.Registration == nil {
|
if p.user.Registration == nil {
|
||||||
if err := p.registerACME(); err != nil {
|
if err := p.registerACME(); err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -100,22 +115,22 @@ func (p *Provider) ObtainCert() E.Error {
|
||||||
Bundle: true,
|
Bundle: true,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = p.saveCert(cert); err != nil {
|
if err = p.saveCert(cert); err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
|
tlsCert, err := tls.X509KeyPair(cert.Certificate, cert.PrivateKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
expiries, err := getCertExpiries(&tlsCert)
|
expiries, err := getCertExpiries(&tlsCert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
p.tlsCert = &tlsCert
|
p.tlsCert = &tlsCert
|
||||||
p.certExpiries = expiries
|
p.certExpiries = expiries
|
||||||
|
@ -123,14 +138,14 @@ func (p *Provider) ObtainCert() E.Error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) LoadCert() E.Error {
|
func (p *Provider) LoadCert() error {
|
||||||
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
cert, err := tls.LoadX509KeyPair(p.cfg.CertPath, p.cfg.KeyPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Errorf("load SSL certificate: %w", err)
|
return fmt.Errorf("load SSL certificate: %w", err)
|
||||||
}
|
}
|
||||||
expiries, err := getCertExpiries(&cert)
|
expiries, err := getCertExpiries(&cert)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.Errorf("parse SSL certificate: %w", err)
|
return fmt.Errorf("parse SSL certificate: %w", err)
|
||||||
}
|
}
|
||||||
p.tlsCert = &cert
|
p.tlsCert = &cert
|
||||||
p.certExpiries = expiries
|
p.certExpiries = expiries
|
||||||
|
@ -149,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||||
if p.GetName() == ProviderLocal {
|
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -171,7 +186,7 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := p.renewIfNeeded(); err != nil {
|
if err := p.renewIfNeeded(); err != nil {
|
||||||
E.LogWarn("cert renew failed", err)
|
gperr.LogWarn("cert renew failed", err)
|
||||||
lastErrOn = time.Now()
|
lastErrOn = time.Now()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -184,10 +199,10 @@ func (p *Provider) ScheduleRenewal(parent task.Parent) {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) initClient() E.Error {
|
func (p *Provider) initClient() error {
|
||||||
legoClient, err := lego.NewClient(p.legoCfg)
|
legoClient, err := lego.NewClient(p.legoCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
generator := providersGenMap[p.cfg.Provider]
|
generator := providersGenMap[p.cfg.Provider]
|
||||||
|
@ -198,7 +213,7 @@ func (p *Provider) initClient() E.Error {
|
||||||
|
|
||||||
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
err = legoClient.Challenge.SetDNS01Provider(legoProvider)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return E.From(err)
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
p.client = legoClient
|
p.client = legoClient
|
||||||
|
@ -273,7 +288,7 @@ func (p *Provider) certState() CertState {
|
||||||
return CertStateValid
|
return CertStateValid
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Provider) renewIfNeeded() E.Error {
|
func (p *Provider) renewIfNeeded() error {
|
||||||
if p.cfg.Provider == ProviderLocal {
|
if p.cfg.Provider == ProviderLocal {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -312,13 +327,13 @@ func providerGenerator[CT any, PT challenge.Provider](
|
||||||
defaultCfg func() *CT,
|
defaultCfg func() *CT,
|
||||||
newProvider func(*CT) (PT, error),
|
newProvider func(*CT) (PT, error),
|
||||||
) ProviderGenerator {
|
) ProviderGenerator {
|
||||||
return func(opt ProviderOpt) (challenge.Provider, E.Error) {
|
return func(opt ProviderOpt) (challenge.Provider, gperr.Error) {
|
||||||
cfg := defaultCfg()
|
cfg := defaultCfg()
|
||||||
err := U.Deserialize(opt, &cfg)
|
err := U.Deserialize(opt, &cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
p, pErr := newProvider(cfg)
|
p, pErr := newProvider(cfg)
|
||||||
return p, E.From(pErr)
|
return p, gperr.Wrap(pErr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,16 @@
|
||||||
package autocert
|
package autocert
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *Provider) Setup() (err E.Error) {
|
func (p *Provider) Setup() (err error) {
|
||||||
if err = p.LoadCert(); err != nil {
|
if err = p.LoadCert(); err != nil {
|
||||||
if !err.Is(os.ErrNotExist) { // ignore if cert doesn't exist
|
if !errors.Is(err, os.ErrNotExist) { // ignore if cert doesn't exist
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
logging.Debug().Msg("obtaining cert due to error loading cert")
|
logging.Debug().Msg("obtaining cert due to error loading cert")
|
||||||
|
|
Loading…
Add table
Reference in a new issue