diff --git a/bin/go-proxy b/bin/go-proxy index 2affb93..ac83950 100755 Binary files a/bin/go-proxy and b/bin/go-proxy differ diff --git a/src/go-proxy/autocert.go b/src/go-proxy/autocert.go index a5d8bda..cf11fca 100644 --- a/src/go-proxy/autocert.go +++ b/src/go-proxy/autocert.go @@ -15,16 +15,21 @@ import ( "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certificate" + "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/providers/dns/cloudflare" "github.com/go-acme/lego/v4/registration" ) +type ProviderOptions = map[string]string +type ProviderGenerator = func(ProviderOptions) (challenge.Provider, error) +type CertExpiries = map[string]time.Time + type AutoCertConfig struct { Email string Domains []string `yaml:",flow"` Provider string - Options map[string]string `yaml:",flow"` + Options ProviderOptions `yaml:",flow"` } type AutoCertUser struct { @@ -46,11 +51,11 @@ func (u *AutoCertUser) GetPrivateKey() crypto.PrivateKey { type AutoCertProvider interface { GetCert(*tls.ClientHelloInfo) (*tls.Certificate, error) GetName() string - GetExpiry() time.Time + GetExpiries() CertExpiries LoadCert() bool ObtainCert() error - - needRenew() bool + RenewalOn() time.Time + ScheduleRenewal() } func (cfg AutoCertConfig) GetProvider() (AutoCertProvider, error) { @@ -78,58 +83,56 @@ func (cfg AutoCertConfig) GetProvider() (AutoCertProvider, error) { if err != nil { return nil, fmt.Errorf("unable to create lego client: %v", err) } - base := &AutoCertProviderBase{ + base := &autoCertProvider{ name: cfg.Provider, cfg: cfg, user: user, legoCfg: legoCfg, client: legoClient, } - switch cfg.Provider { - case "cloudflare": - return NewAutoCertCFProvider(base, cfg.Options) + gen, ok := providersGenMap[cfg.Provider] + if !ok { + return nil, fmt.Errorf("unknown provider: %s", cfg.Provider) } - return nil, fmt.Errorf("unknown provider: %s", cfg.Provider) + legoProvider, err := gen(cfg.Options) + if err != nil { + return nil, fmt.Errorf("unable to create provider: %v", err) + } + err = legoClient.Challenge.SetDNS01Provider(legoProvider) + if err != nil { + return nil, fmt.Errorf("unable to set challenge provider: %v", err) + } + return base, nil } -type AutoCertProviderBase struct { +type autoCertProvider struct { name string cfg AutoCertConfig user *AutoCertUser legoCfg *lego.Config client *lego.Client - tlsCert *tls.Certificate - expiry time.Time - mutex sync.Mutex + tlsCert *tls.Certificate + certExpiries CertExpiries + mutex sync.Mutex } -func (p *AutoCertProviderBase) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { +func (p *autoCertProvider) GetCert(_ *tls.ClientHelloInfo) (*tls.Certificate, error) { if p.tlsCert == nil { aclog.Fatal("no certificate available") } - if p.needRenew() { - p.mutex.Lock() - defer p.mutex.Unlock() - if p.needRenew() { - err := p.ObtainCert() - if err != nil { - return nil, err - } - } - } return p.tlsCert, nil } -func (p *AutoCertProviderBase) GetName() string { +func (p *autoCertProvider) GetName() string { return p.name } -func (p *AutoCertProviderBase) GetExpiry() time.Time { - return p.expiry +func (p *autoCertProvider) GetExpiries() CertExpiries { + return p.certExpiries } -func (p *AutoCertProviderBase) ObtainCert() error { +func (p *autoCertProvider) ObtainCert() error { client := p.client if p.user.Registration == nil { reg, err := client.Registration.Register(registration.RegisterOptions{TermsOfServiceAgreed: true}) @@ -154,30 +157,55 @@ func (p *AutoCertProviderBase) ObtainCert() error { if err != nil { return err } - p.tlsCert = &tlsCert - x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[len(tlsCert.Certificate)-1]) + expiries, err := getCertExpiries(&tlsCert) if err != nil { return err } - p.expiry = x509Cert.NotAfter + p.tlsCert = &tlsCert + p.certExpiries = expiries return nil } -func (p *AutoCertProviderBase) LoadCert() bool { +func (p *autoCertProvider) LoadCert() bool { cert, err := tls.LoadX509KeyPair(certFileDefault, keyFileDefault) if err != nil { return false } - x509Cert, err := x509.ParseCertificate(cert.Certificate[len(cert.Certificate)-1]) + expiries, err := getCertExpiries(&cert) if err != nil { return false } p.tlsCert = &cert - p.expiry = x509Cert.NotAfter + p.certExpiries = expiries + p.renewIfNeeded() return true } -func (p *AutoCertProviderBase) saveCert(cert *certificate.Resource) error { +func (p *autoCertProvider) RenewalOn() time.Time { + t := time.Now().AddDate(0, 0, 3) + for _, expiry := range p.certExpiries { + if expiry.Before(t) { + return time.Now() + } + return t + } + // this line should never be reached + panic("no certificate available") +} + +func (p *autoCertProvider) ScheduleRenewal() { + for { + t := time.Until(p.RenewalOn()) + aclog.Infof("next renewal in %v", t) + time.Sleep(t) + err := p.renewIfNeeded() + if err != nil { + aclog.Fatal(err) + } + } +} + +func (p *autoCertProvider) saveCert(cert *certificate.Resource) error { err := os.MkdirAll(path.Dir(certFileDefault), 0644) if err != nil { return fmt.Errorf("unable to create cert directory: %v", err) @@ -193,36 +221,68 @@ func (p *AutoCertProviderBase) saveCert(cert *certificate.Resource) error { return nil } -func (p *AutoCertProviderBase) needRenew() bool { - return p.expiry.Before(time.Now().Add(24 * time.Hour)) +func (p *autoCertProvider) needRenewal() bool { + return time.Now().After(p.RenewalOn()) } -type AutoCertCFProvider struct { - *AutoCertProviderBase - *cloudflare.Config +func (p *autoCertProvider) renewIfNeeded() error { + if !p.needRenewal() { + return nil + } + + p.mutex.Lock() + defer p.mutex.Unlock() + + if !p.needRenewal() { + return nil + } + + trials := 0 + for { + err := p.ObtainCert() + if err == nil { + return nil + } + trials++ + if trials > 3 { + return fmt.Errorf("unable to renew certificate: %v after 3 trials", err) + } + aclog.Errorf("failed to renew certificate: %v, trying again in 5 seconds", err) + time.Sleep(5 * time.Second) + } } -func NewAutoCertCFProvider(base *AutoCertProviderBase, opt map[string]string) (*AutoCertCFProvider, error) { - p := &AutoCertCFProvider{ - base, - cloudflare.NewDefaultConfig(), +func providerGenerator[CT interface{}, PT challenge.Provider](defaultCfg func() *CT, newProvider func(*CT) (PT, error)) ProviderGenerator { + return func(opt ProviderOptions) (challenge.Provider, error) { + cfg := defaultCfg() + err := setOptions(cfg, opt) + if err != nil { + return nil, err + } + p, err := newProvider(cfg) + if err != nil { + return nil, err + } + return p, nil } - err := setOptions(p.Config, opt) - if err != nil { - return nil, err - } - legoProvider, err := cloudflare.NewDNSProviderConfig(p.Config) - if err != nil { - return nil, fmt.Errorf("unable to create cloudflare provider: %v", err) - } - err = p.client.Challenge.SetDNS01Provider(legoProvider) - if err != nil { - return nil, fmt.Errorf("unable to set challenge provider: %v", err) - } - return p, nil } -func setOptions[T interface{}](cfg *T, opt map[string]string) error { +func getCertExpiries(cert *tls.Certificate) (CertExpiries, error) { + r := make(CertExpiries, len(cert.Certificate)) + for _, cert := range cert.Certificate { + x509Cert, err := x509.ParseCertificate(cert) + if err != nil { + return nil, err + } + if x509Cert.IsCA { + continue + } + r[x509Cert.Subject.CommonName] = x509Cert.NotAfter + } + return r, nil +} + +func setOptions[T interface{}](cfg *T, opt ProviderOptions) error { for k, v := range opt { err := SetFieldFromSnake(cfg, k, v) if err != nil { @@ -231,3 +291,7 @@ func setOptions[T interface{}](cfg *T, opt map[string]string) error { } return nil } + +var providersGenMap = map[string]ProviderGenerator{ + "cloudflare": providerGenerator(cloudflare.NewDefaultConfig, cloudflare.NewDNSProviderConfig), +} diff --git a/src/go-proxy/main.go b/src/go-proxy/main.go index 8ac70f6..991a270 100755 --- a/src/go-proxy/main.go +++ b/src/go-proxy/main.go @@ -11,7 +11,6 @@ import ( ) func main() { - // flag.Parse() runtime.GOMAXPROCS(runtime.NumCPU()) logrus.SetFormatter(&logrus.TextFormatter{ @@ -52,7 +51,10 @@ func main() { aclog.Fatal("error obtaining certificate ", err) } } - aclog.Infof("certificate will be expired at %v and get renewed", autoCertProvider.GetExpiry()) + for name, expiry := range autoCertProvider.GetExpiries() { + aclog.Infof("certificate %q: expire on %v", name, expiry) + } + go autoCertProvider.ScheduleRenewal() } proxyServer = NewServer( "proxy", @@ -86,9 +88,8 @@ func main() { signal.Notify(sig, syscall.SIGHUP) <-sig - cfg.StopWatching() - StopFSWatcher() - StopDockerWatcher() + // cfg.StopWatching() + cfg.StopProviders() panelServer.Stop() proxyServer.Stop() diff --git a/src/go-proxy/server.go b/src/go-proxy/server.go index 692cf80..09847e5 100644 --- a/src/go-proxy/server.go +++ b/src/go-proxy/server.go @@ -79,11 +79,13 @@ func (s *Server) Stop() { if s.httpStarted { errHTTP := s.http.Shutdown(ctx) s.handleErr("http", errHTTP) + s.httpStarted = false } if s.httpsStarted { errHTTPS := s.https.Shutdown(ctx) s.handleErr("https", errHTTPS) + s.httpsStarted = false } } diff --git a/src/go-proxy/watcher.go b/src/go-proxy/watcher.go index d5c0137..8776ec2 100644 --- a/src/go-proxy/watcher.go +++ b/src/go-proxy/watcher.go @@ -127,12 +127,12 @@ func InitFSWatcher() { func InitDockerWatcher() { // stop all docker client on watcher stop go func() { - defer dockerWatcherWg.Done() <-dockerWatcherStop ParallelForEachValue( dockerWatchMap.Iterator(), (*dockerWatcher).Dispose, ) + dockerWatcherWg.Done() }() }