diff --git a/internal/autocert/config.go b/internal/autocert/config.go index 0947df8..a71d429 100644 --- a/internal/autocert/config.go +++ b/internal/autocert/config.go @@ -5,12 +5,15 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/x509" + "net/http" "os" "regexp" "github.com/go-acme/lego/v4/certcrypto" + "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/lego" "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/utils" ) @@ -22,13 +25,19 @@ type Config struct { KeyPath string `json:"key_path,omitempty"` ACMEKeyPath string `json:"acme_key_path,omitempty"` Provider string `json:"provider,omitempty"` + CADirURL string `json:"ca_dir_url,omitempty"` Options map[string]any `json:"options,omitempty"` + + HTTPClient *http.Client `json:"-"` // for tests only + + challengeProvider challenge.Provider } var ( ErrMissingDomain = gperr.New("missing field 'domains'") ErrMissingEmail = gperr.New("missing field 'email'") ErrMissingProvider = gperr.New("missing field 'provider'") + ErrMissingCADirURL = gperr.New("missing field 'ca_dir_url'") ErrInvalidDomain = gperr.New("invalid domain") ErrUnknownProvider = gperr.New("unknown provider") ) @@ -36,6 +45,7 @@ var ( const ( ProviderLocal = "local" ProviderPseudo = "pseudo" + ProviderCustom = "custom" ) var domainOrWildcardRE = regexp.MustCompile(`^\*?([^.]+\.)+[^.]+$`) @@ -52,6 +62,10 @@ func (cfg *Config) Validate() gperr.Error { } b := gperr.NewBuilder("autocert errors") + if cfg.Provider == ProviderCustom && cfg.CADirURL == "" { + b.Add(ErrMissingCADirURL) + } + if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo { if len(cfg.Domains) == 0 { b.Add(ErrMissingDomain) @@ -59,24 +73,34 @@ func (cfg *Config) Validate() gperr.Error { if cfg.Email == "" { b.Add(ErrMissingEmail) } - for i, d := range cfg.Domains { - if !domainOrWildcardRE.MatchString(d) { - b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) + if cfg.Provider != ProviderCustom { + for i, d := range cfg.Domains { + if !domainOrWildcardRE.MatchString(d) { + b.Add(ErrInvalidDomain.Subjectf("domains[%d]", i)) + } } } // check if provider is implemented providerConstructor, ok := Providers[cfg.Provider] if !ok { - b.Add(ErrUnknownProvider. - Subject(cfg.Provider). - With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers)))) + if cfg.Provider != ProviderCustom { + b.Add(ErrUnknownProvider. + Subject(cfg.Provider). + With(gperr.DoYouMean(utils.NearestField(cfg.Provider, Providers)))) + } } else { - _, err := providerConstructor(cfg.Options) + provider, err := providerConstructor(cfg.Options) if err != nil { b.Add(err) + } else { + cfg.challengeProvider = provider } } } + + if cfg.challengeProvider == nil { + cfg.challengeProvider, _ = Providers[ProviderLocal](nil) + } return b.Error() } @@ -119,10 +143,21 @@ func (cfg *Config) GetLegoConfig() (*User, *lego.Config, gperr.Error) { legoCfg := lego.NewConfig(user) legoCfg.Certificate.KeyType = certcrypto.EC256 + if cfg.HTTPClient != nil { + legoCfg.HTTPClient = cfg.HTTPClient + } + + if cfg.CADirURL != "" { + legoCfg.CADirURL = cfg.CADirURL + } + return user, legoCfg, nil } func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) { + if common.IsTest { + return nil, os.ErrNotExist + } data, err := os.ReadFile(cfg.ACMEKeyPath) if err != nil { return nil, err @@ -131,6 +166,9 @@ func (cfg *Config) LoadACMEKey() (*ecdsa.PrivateKey, error) { } func (cfg *Config) SaveACMEKey(key *ecdsa.PrivateKey) error { + if common.IsTest { + return nil + } data, err := x509.MarshalECPrivateKey(key) if err != nil { return err diff --git a/internal/autocert/provider.go b/internal/autocert/provider.go index 58e9afa..12c0125 100644 --- a/internal/autocert/provider.go +++ b/internal/autocert/provider.go @@ -17,6 +17,7 @@ import ( "github.com/go-acme/lego/v4/registration" "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/notif" "github.com/yusing/go-proxy/internal/task" @@ -77,12 +78,10 @@ func (p *Provider) ObtainCert() error { } if p.cfg.Provider == ProviderPseudo { - t := time.NewTicker(1000 * time.Millisecond) - defer t.Stop() log.Info().Msg("init client for pseudo provider") - <-t.C + <-time.After(time.Second) log.Info().Msg("registering acme for pseudo provider") - <-t.C + <-time.After(time.Second) log.Info().Msg("obtained cert for pseudo provider") return nil } @@ -220,13 +219,7 @@ func (p *Provider) initClient() error { return err } - generator := Providers[p.cfg.Provider] - legoProvider, pErr := generator(p.cfg.Options) - if pErr != nil { - return pErr - } - - err = legoClient.Challenge.SetDNS01Provider(legoProvider) + err = legoClient.Challenge.SetDNS01Provider(p.cfg.challengeProvider) if err != nil { return err } @@ -255,6 +248,9 @@ func (p *Provider) registerACME() error { } func (p *Provider) saveCert(cert *certificate.Resource) error { + if common.IsTest { + return nil + } /* This should have been done in setup but double check is always a good choice.*/ _, err := os.Stat(path.Dir(p.cfg.CertPath)) diff --git a/internal/autocert/provider_test/custom_test.go b/internal/autocert/provider_test/custom_test.go new file mode 100644 index 0000000..af7e1c2 --- /dev/null +++ b/internal/autocert/provider_test/custom_test.go @@ -0,0 +1,453 @@ +//nolint:errchkjson,errcheck +package provider_test + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "encoding/pem" + "io" + "math/big" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/yusing/go-proxy/internal/autocert" + "github.com/yusing/go-proxy/internal/dnsproviders" +) + +func TestMain(m *testing.M) { + dnsproviders.InitProviders() + m.Run() +} + +func TestCustomProvider(t *testing.T) { + t.Run("valid custom provider with step-ca", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "test@example.com", + Domains: []string{"example.com", "*.example.com"}, + Provider: autocert.ProviderCustom, + CADirURL: "https://ca.example.com:9000/acme/acme/directory", + CertPath: "certs/custom.crt", + KeyPath: "certs/custom.key", + ACMEKeyPath: "certs/custom-acme.key", + } + + err := cfg.Validate() + require.NoError(t, err) + + user, legoCfg, err := cfg.GetLegoConfig() + require.NoError(t, err) + require.NotNil(t, user) + require.NotNil(t, legoCfg) + require.Equal(t, "https://ca.example.com:9000/acme/acme/directory", legoCfg.CADirURL) + require.Equal(t, "test@example.com", user.Email) + }) + + t.Run("custom provider missing CADirURL", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "test@example.com", + Domains: []string{"example.com"}, + Provider: autocert.ProviderCustom, + // CADirURL is missing + } + + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), "missing field 'ca_dir_url'") + }) + + t.Run("custom provider with step-ca internal CA", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "admin@internal.com", + Domains: []string{"internal.example.com", "api.internal.example.com"}, + Provider: autocert.ProviderCustom, + CADirURL: "https://step-ca.internal:443/acme/acme/directory", + CertPath: "certs/internal.crt", + KeyPath: "certs/internal.key", + ACMEKeyPath: "certs/internal-acme.key", + } + + err := cfg.Validate() + require.NoError(t, err) + + user, legoCfg, err := cfg.GetLegoConfig() + require.NoError(t, err) + require.NotNil(t, user) + require.NotNil(t, legoCfg) + require.Equal(t, "https://step-ca.internal:443/acme/acme/directory", legoCfg.CADirURL) + require.Equal(t, "admin@internal.com", user.Email) + + provider := autocert.NewProvider(cfg, user, legoCfg) + require.NotNil(t, provider) + require.Equal(t, autocert.ProviderCustom, provider.GetName()) + require.Equal(t, "certs/internal.crt", provider.GetCertPath()) + require.Equal(t, "certs/internal.key", provider.GetKeyPath()) + }) +} + +func TestObtainCertFromCustomProvider(t *testing.T) { + // Create a test ACME server + acmeServer := newTestACMEServer(t) + defer acmeServer.Close() + + t.Run("obtain cert from custom step-ca server", func(t *testing.T) { + cfg := &autocert.Config{ + Email: "test@example.com", + Domains: []string{"test.example.com"}, + Provider: autocert.ProviderCustom, + CADirURL: acmeServer.URL() + "/acme/acme/directory", + CertPath: "certs/stepca-test.crt", + KeyPath: "certs/stepca-test.key", + ACMEKeyPath: "certs/stepca-test-acme.key", + HTTPClient: acmeServer.httpClient(), + } + + err := error(cfg.Validate()) + require.NoError(t, err) + + user, legoCfg, err := cfg.GetLegoConfig() + require.NoError(t, err) + require.NotNil(t, user) + require.NotNil(t, legoCfg) + + provider := autocert.NewProvider(cfg, user, legoCfg) + require.NotNil(t, provider) + + // Test obtaining certificate + err = provider.ObtainCert() + require.NoError(t, err) + + // Verify certificate was obtained + cert, err := provider.GetCert(nil) + require.NoError(t, err) + require.NotNil(t, cert) + + // Verify certificate properties + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + require.NoError(t, err) + require.Contains(t, x509Cert.DNSNames, "test.example.com") + require.True(t, time.Now().Before(x509Cert.NotAfter)) + require.True(t, time.Now().After(x509Cert.NotBefore)) + }) +} + +// testACMEServer implements a minimal ACME server for testing. +type testACMEServer struct { + server *httptest.Server + caCert *x509.Certificate + caKey *rsa.PrivateKey + clientCSRs map[string]*x509.CertificateRequest + orderID string +} + +func newTestACMEServer(t *testing.T) *testACMEServer { + t.Helper() + + // Generate CA certificate and key + caKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + + caTemplate := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{ + Organization: []string{"Test CA"}, + Country: []string{"US"}, + Province: []string{""}, + Locality: []string{"Test"}, + StreetAddress: []string{""}, + PostalCode: []string{""}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caCertDER, err := x509.CreateCertificate(rand.Reader, caTemplate, caTemplate, &caKey.PublicKey, caKey) + require.NoError(t, err) + + caCert, err := x509.ParseCertificate(caCertDER) + require.NoError(t, err) + + acme := &testACMEServer{ + caCert: caCert, + caKey: caKey, + clientCSRs: make(map[string]*x509.CertificateRequest), + orderID: "test-order-123", + } + + mux := http.NewServeMux() + acme.setupRoutes(mux) + + acme.server = httptest.NewTLSServer(mux) + return acme +} + +func (s *testACMEServer) Close() { + s.server.Close() +} + +func (s *testACMEServer) URL() string { + return s.server.URL +} + +func (s *testACMEServer) httpClient() *http.Client { + return &http.Client{ + Transport: &http.Transport{ + DialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + TLSHandshakeTimeout: 30 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: true, + }, + }, + } +} + +func (s *testACMEServer) setupRoutes(mux *http.ServeMux) { + // ACME directory endpoint + mux.HandleFunc("/acme/acme/directory", s.handleDirectory) + + // ACME endpoints + mux.HandleFunc("/acme/new-nonce", s.handleNewNonce) + mux.HandleFunc("/acme/new-account", s.handleNewAccount) + mux.HandleFunc("/acme/new-order", s.handleNewOrder) + mux.HandleFunc("/acme/authz/", s.handleAuthorization) + mux.HandleFunc("/acme/chall/", s.handleChallenge) + mux.HandleFunc("/acme/order/", s.handleOrder) + mux.HandleFunc("/acme/cert/", s.handleCertificate) +} + +func (s *testACMEServer) handleDirectory(w http.ResponseWriter, r *http.Request) { + directory := map[string]interface{}{ + "newNonce": s.server.URL + "/acme/new-nonce", + "newAccount": s.server.URL + "/acme/new-account", + "newOrder": s.server.URL + "/acme/new-order", + "keyChange": s.server.URL + "/acme/key-change", + "meta": map[string]interface{}{ + "termsOfService": s.server.URL + "/terms", + }, + } + + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(directory) +} + +func (s *testACMEServer) handleNewNonce(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Replay-Nonce", "test-nonce-12345") + w.WriteHeader(http.StatusOK) +} + +func (s *testACMEServer) handleNewAccount(w http.ResponseWriter, r *http.Request) { + account := map[string]interface{}{ + "status": "valid", + "contact": []string{"mailto:test@example.com"}, + "orders": s.server.URL + "/acme/orders", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", s.server.URL+"/acme/account/1") + w.Header().Set("Replay-Nonce", "test-nonce-67890") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(account) +} + +func (s *testACMEServer) handleNewOrder(w http.ResponseWriter, r *http.Request) { + authzID := "test-authz-456" + + order := map[string]interface{}{ + "status": "ready", // Skip pending state for simplicity + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, + "authorizations": []string{s.server.URL + "/acme/authz/" + authzID}, + "finalize": s.server.URL + "/acme/order/" + s.orderID + "/finalize", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", s.server.URL+"/acme/order/"+s.orderID) + w.Header().Set("Replay-Nonce", "test-nonce-order") + w.WriteHeader(http.StatusCreated) + json.NewEncoder(w).Encode(order) +} + +func (s *testACMEServer) handleAuthorization(w http.ResponseWriter, r *http.Request) { + authz := map[string]interface{}{ + "status": "valid", // Skip challenge validation for simplicity + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifier": map[string]string{"type": "dns", "value": "test.example.com"}, + "challenges": []map[string]interface{}{ + { + "type": "dns-01", + "status": "valid", + "url": s.server.URL + "/acme/chall/test-chall-789", + "token": "test-token-abc123", + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-authz") + json.NewEncoder(w).Encode(authz) +} + +func (s *testACMEServer) handleChallenge(w http.ResponseWriter, r *http.Request) { + challenge := map[string]interface{}{ + "type": "dns-01", + "status": "valid", + "url": r.URL.String(), + "token": "test-token-abc123", + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-chall") + json.NewEncoder(w).Encode(challenge) +} + +func (s *testACMEServer) handleOrder(w http.ResponseWriter, r *http.Request) { + if strings.HasSuffix(r.URL.Path, "/finalize") { + s.handleFinalize(w, r) + return + } + + certURL := s.server.URL + "/acme/cert/" + s.orderID + order := map[string]interface{}{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, + "certificate": certURL, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Replay-Nonce", "test-nonce-order-get") + json.NewEncoder(w).Encode(order) +} + +func (s *testACMEServer) handleFinalize(w http.ResponseWriter, r *http.Request) { + // Read the JWS payload + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "Failed to read request", http.StatusBadRequest) + return + } + + // Extract CSR from JWS payload + csr, err := s.extractCSRFromJWS(body) + if err != nil { + http.Error(w, "Invalid CSR: "+err.Error(), http.StatusBadRequest) + return + } + + // Store the CSR for certificate generation + s.clientCSRs[s.orderID] = csr + + certURL := s.server.URL + "/acme/cert/" + s.orderID + order := map[string]interface{}{ + "status": "valid", + "expires": time.Now().Add(24 * time.Hour).Format(time.RFC3339), + "identifiers": []map[string]string{{"type": "dns", "value": "test.example.com"}}, + "certificate": certURL, + } + + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Location", strings.TrimSuffix(r.URL.String(), "/finalize")) + w.Header().Set("Replay-Nonce", "test-nonce-finalize") + json.NewEncoder(w).Encode(order) +} + +func (s *testACMEServer) extractCSRFromJWS(jwsData []byte) (*x509.CertificateRequest, error) { + // Parse the JWS structure + var jws struct { + Protected string `json:"protected"` + Payload string `json:"payload"` + Signature string `json:"signature"` + } + + if err := json.Unmarshal(jwsData, &jws); err != nil { + return nil, err + } + + // Decode the payload + payloadBytes, err := base64.RawURLEncoding.DecodeString(jws.Payload) + if err != nil { + return nil, err + } + + // Parse the finalize request + var finalizeReq struct { + CSR string `json:"csr"` + } + + if err := json.Unmarshal(payloadBytes, &finalizeReq); err != nil { + return nil, err + } + + // Decode the CSR + csrBytes, err := base64.RawURLEncoding.DecodeString(finalizeReq.CSR) + if err != nil { + return nil, err + } + + // Parse the CSR + csr, err := x509.ParseCertificateRequest(csrBytes) + if err != nil { + return nil, err + } + + return csr, nil +} + +func (s *testACMEServer) handleCertificate(w http.ResponseWriter, r *http.Request) { + // Extract order ID from URL + orderID := strings.TrimPrefix(r.URL.Path, "/acme/cert/") + + // Get the CSR for this order + csr, exists := s.clientCSRs[orderID] + if !exists { + http.Error(w, "No CSR found for order", http.StatusBadRequest) + return + } + + // Create certificate using the public key from the client's CSR + template := &x509.Certificate{ + SerialNumber: big.NewInt(2), + Subject: pkix.Name{ + Organization: []string{"Test Cert"}, + Country: []string{"US"}, + }, + DNSNames: csr.DNSNames, + NotBefore: time.Now(), + NotAfter: time.Now().Add(90 * 24 * time.Hour), + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + } + + // Use the public key from the CSR and sign with CA key + certDER, err := x509.CreateCertificate(rand.Reader, template, s.caCert, csr.PublicKey, s.caKey) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // Return certificate chain + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + caPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: s.caCert.Raw}) + + w.Header().Set("Content-Type", "application/pem-certificate-chain") + w.Header().Set("Replay-Nonce", "test-nonce-cert") + w.Write(append(certPEM, caPEM...)) +} diff --git a/internal/autocert/providers.go b/internal/autocert/providers.go index 222c834..5b73e39 100644 --- a/internal/autocert/providers.go +++ b/internal/autocert/providers.go @@ -16,9 +16,11 @@ func DNSProvider[CT any, PT challenge.Provider]( ) Generator { return func(opt map[string]any) (challenge.Provider, gperr.Error) { cfg := defaultCfg() - err := serialization.MapUnmarshalValidate(opt, &cfg) - if err != nil { - return nil, err + if len(opt) > 0 { + err := serialization.MapUnmarshalValidate(opt, &cfg) + if err != nil { + return nil, err + } } p, pErr := newProvider(cfg) return p, gperr.Wrap(pErr)