add cert info and renewal api

This commit is contained in:
yusing 2025-02-15 21:50:34 +08:00
parent 7129e2cc9d
commit 16b046bd44
15 changed files with 201 additions and 47 deletions

View file

@ -90,7 +90,7 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(ca)
if !ok {
return gperr.New("invalid CA certificate")
return gperr.New("invalid ca certificate")
}
cfg.tlsConfig = &tls.Config{
@ -128,21 +128,18 @@ func (cfg *AgentConfig) StartWithCerts(parent task.Parent, ca, crt, key []byte)
return nil
}
func (cfg *AgentConfig) Start(parent task.Parent) error {
func (cfg *AgentConfig) Start(parent task.Parent) gperr.Error {
certData, err := os.ReadFile(certs.AgentCertsFilename(cfg.Addr))
if err != nil {
if os.IsNotExist(err) {
return gperr.Errorf("agents certs not found, did you run `godoxy new-agent %s ...`?", cfg.Addr)
}
return gperr.Wrap(err)
return gperr.Wrap(err, "failed to read agent certs")
}
ca, crt, key, err := certs.ExtractCert(certData)
if err != nil {
return gperr.Wrap(err)
return gperr.Wrap(err, "failed to extract agent certs")
}
return cfg.StartWithCerts(parent, ca, crt, key)
return gperr.Wrap(cfg.StartWithCerts(parent, ca, crt, key))
}
func (cfg *AgentConfig) NewHTTPClient() *http.Client {

View file

@ -6,7 +6,6 @@ import (
"crypto/x509"
"encoding/pem"
"fmt"
"log"
"net"
"net/http"
"time"
@ -45,7 +44,6 @@ func StartAgentServer(parent task.Parent, opt Options) {
agentServer := &http.Server{
Handler: handler.NewAgentHandler(),
TLSConfig: tlsConfig,
ErrorLog: log.New(logger, "", 0),
}
go func() {

2
go.mod
View file

@ -3,7 +3,7 @@ module github.com/yusing/go-proxy
go 1.24.0
require (
github.com/PuerkitoBio/goquery v1.10.1
github.com/PuerkitoBio/goquery v1.10.2
github.com/coder/websocket v1.8.12
github.com/coreos/go-oidc/v3 v3.12.0
github.com/docker/cli v27.5.1+incompatible

4
go.sum
View file

@ -2,8 +2,8 @@ github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOEl
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/PuerkitoBio/goquery v1.10.1 h1:Y8JGYUkXWTGRB6Ars3+j3kN0xg1YqqlwvdTV8WTFQcU=
github.com/PuerkitoBio/goquery v1.10.1/go.mod h1:IYiHrOMps66ag56LEH7QYDDupKXyo5A8qrjIx3ZtujY=
github.com/PuerkitoBio/goquery v1.10.2 h1:7fh2BdHcG6VFZsK7toXBT/Bh1z5Wmy8Q9MV9HqT2AM8=
github.com/PuerkitoBio/goquery v1.10.2/go.mod h1:0guWGjcLu9AYC7C1GHnpysHy056u9aEkUHwhdnePMCU=
github.com/andybalholm/cascadia v1.3.3 h1:AG2YHrzJIm4BZ19iwJ/DAua6Btl3IwJX+VI4kktS1LM=
github.com/andybalholm/cascadia v1.3.3/go.mod h1:xNd9bqTn98Ln4DwST8/nG+H0yuB8Hmgu1YHNnWw0GeA=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=

View file

@ -7,6 +7,7 @@ import (
"github.com/prometheus/client_golang/prometheus/promhttp"
v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/certapi"
"github.com/yusing/go-proxy/internal/api/v1/favicon"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types"
@ -54,10 +55,13 @@ func (mux ServeMux) HandleFunc(methods, endpoint string, h any, requireAuth ...b
if len(requireAuth) > 0 && requireAuth[0] {
handler = auth.RequireAuth(handler)
}
if methods == "" {
mux.ServeMux.HandleFunc(endpoint, handler)
} else {
for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
}
}
func NewHandler(cfg config.ConfigInstance) http.Handler {
@ -82,6 +86,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux.HandleFunc("POST", "/v1/agents/add", v1.AddAgent, true)
mux.HandleFunc("GET", "/v1/metrics/system_info", v1.SystemInfo, true)
mux.HandleFunc("GET", "/v1/metrics/uptime", uptime.Poller.ServeHTTP, true)
mux.HandleFunc("GET", "/v1/cert/info", certapi.GetCertInfo, true)
mux.HandleFunc("", "/v1/cert/renew", certapi.RenewCert, true)
if common.PrometheusEnabled {
mux.Handle("GET /v1/metrics", promhttp.Handler())

View file

@ -0,0 +1,41 @@
package certapi
import (
"encoding/json"
"net/http"
config "github.com/yusing/go-proxy/internal/config/types"
)
type CertInfo struct {
Subject string `json:"subject"`
Issuer string `json:"issuer"`
NotBefore int64 `json:"not_before"`
NotAfter int64 `json:"not_after"`
DNSNames []string `json:"dns_names"`
EmailAddresses []string `json:"email_addresses"`
}
func GetCertInfo(w http.ResponseWriter, r *http.Request) {
autocert := config.GetInstance().AutoCertProvider()
if autocert == nil {
http.Error(w, "autocert is not enabled", http.StatusNotFound)
return
}
cert, err := autocert.GetCert(nil)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
certInfo := CertInfo{
Subject: cert.Leaf.Subject.CommonName,
Issuer: cert.Leaf.Issuer.CommonName,
NotBefore: cert.Leaf.NotBefore.Unix(),
NotAfter: cert.Leaf.NotAfter.Unix(),
DNSNames: cert.Leaf.DNSNames,
EmailAddresses: cert.Leaf.EmailAddresses,
}
json.NewEncoder(w).Encode(&certInfo)
}

View file

@ -0,0 +1,56 @@
package certapi
import (
"net/http"
config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/net/gphttp/gpwebsocket"
)
func RenewCert(w http.ResponseWriter, r *http.Request) {
autocert := config.GetInstance().AutoCertProvider()
if autocert == nil {
http.Error(w, "autocert is not enabled", http.StatusNotFound)
return
}
conn, err := gpwebsocket.Initiate(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
//nolint:errcheck
defer conn.CloseNow()
logs, cancel := memlogger.Events()
defer cancel()
done := make(chan struct{})
go func() {
defer close(done)
err = autocert.ObtainCert()
if err != nil {
gperr.LogError("failed to obtain cert", err)
gpwebsocket.WriteText(r, conn, err.Error())
} else {
logging.Info().Msg("cert obtained successfully")
}
}()
for {
select {
case l := <-logs:
if err != nil {
return
}
if !gpwebsocket.WriteText(r, conn, string(l)) {
return
}
case <-done:
return
}
}
}

View file

@ -51,7 +51,7 @@ func (cfg *AutocertConfig) Validate() gperr.Error {
}
b := gperr.NewBuilder("autocert errors")
if cfg.Provider != ProviderLocal {
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if len(cfg.Domains) == 0 {
b.Add(ErrMissingDomain)
}
@ -101,7 +101,7 @@ func (cfg *AutocertConfig) GetProvider() (*Provider, gperr.Error) {
var privKey *ecdsa.PrivateKey
var err error
if cfg.Provider != ProviderLocal {
if cfg.Provider != ProviderLocal && cfg.Provider != ProviderPseudo {
if privKey, err = cfg.loadACMEKey(); err != nil {
logging.Info().Err(err).Msg("load ACME private key failed")
logging.Info().Msg("generate new ACME private key")

View file

@ -20,6 +20,7 @@ const (
ProviderClouddns = "clouddns"
ProviderDuckdns = "duckdns"
ProviderOVH = "ovh"
ProviderPseudo = "pseudo" // for testing
)
var providersGenMap = map[string]ProviderGenerator{
@ -28,4 +29,5 @@ var providersGenMap = map[string]ProviderGenerator{
ProviderClouddns: providerGenerator(clouddns.NewDefaultConfig, clouddns.NewDNSProviderConfig),
ProviderDuckdns: providerGenerator(duckdns.NewDefaultConfig, duckdns.NewDNSProviderConfig),
ProviderOVH: providerGenerator(ovh.NewDefaultConfig, ovh.NewDNSProviderConfig),
ProviderPseudo: providerGenerator(NewDummyDefaultConfig, NewDummyDNSProviderConfig),
}

View file

@ -9,6 +9,7 @@ import (
"path"
"reflect"
"sort"
"sync"
"time"
"github.com/go-acme/lego/v4/certificate"
@ -32,6 +33,8 @@ type (
legoCert *certificate.Resource
tlsCert *tls.Certificate
certExpiries CertExpiries
obtainMu sync.Mutex
}
ProviderGenerator func(ProviderOpt) (challenge.Provider, gperr.Error)
@ -68,6 +71,17 @@ func (p *Provider) ObtainCert() error {
return nil
}
if p.cfg.Provider == ProviderPseudo {
t := time.NewTicker(1000 * time.Millisecond)
defer t.Stop()
logging.Info().Msg("init client for pseudo provider")
<-t.C
logging.Info().Msg("registering acme for pseudo provider")
<-t.C
logging.Info().Msg("obtained cert for pseudo provider")
return nil
}
if p.client == nil {
if err := p.initClient(); err != nil {
return err
@ -150,7 +164,7 @@ func (p *Provider) ShouldRenewOn() time.Time {
}
func (p *Provider) ScheduleRenewal(parent task.Parent) {
if p.GetName() == ProviderLocal {
if p.GetName() == ProviderLocal || p.GetName() == ProviderPseudo {
return
}
go func() {

View file

@ -51,10 +51,6 @@ You may run "ls-config" to show or dump the current config.`
var Validate = config.Validate
func GetInstance() *Config {
return config.GetInstance().(*Config)
}
func newConfig() *Config {
return &Config{
value: config.DefaultConfig(),
@ -75,7 +71,7 @@ func Load() (*Config, gperr.Error) {
}
func MatchDomains() []string {
return GetInstance().Value().MatchDomains
return config.GetInstance().Value().MatchDomains
}
func WatchChanges() {
@ -123,7 +119,7 @@ func Reload() gperr.Error {
// cancel all current subtasks -> wait
// -> replace config -> start new subtasks
GetInstance().Task().Finish("config changed")
config.GetInstance().(*Config).Task().Finish("config changed")
newCfg.Start(StartAllServers)
config.SetInstance(newCfg)
return nil

View file

@ -43,6 +43,7 @@ type (
GetAgent(agentAddrOrDockerHost string) (*agent.AgentConfig, bool)
AddAgent(host string, ca agent.PEMPair, client agent.PEMPair) (int, gperr.Error)
ListAgents() []*agent.AgentConfig
AutoCertProvider() *autocert.Provider
}
)

View file

@ -0,0 +1,42 @@
package types
import (
"testing"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/utils"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
func TestValidateConfig(t *testing.T) {
cases := []struct {
name string
data []byte
want gperr.Error
}{
{
name: "valid config",
data: []byte(`
autocert:
provider: local
`),
want: nil,
},
{
name: "unknown field",
data: []byte(`
autocert:
provider: local
unknown: true
`),
want: utils.ErrUnknownField,
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
got := Validate(c.data)
ExpectError(t, c.want, got)
})
}
}

View file

@ -90,35 +90,36 @@ func (m *memLogger) truncateIfNeeded(n int) {
}
func (m *memLogger) notifyWS(pos, n int) {
if m.connChans.Size() > 0 {
timeout := time.NewTimer(2 * time.Second)
if m.connChans.Size() == 0 && m.listeners.Size() == 0 {
return
}
timeout := time.NewTimer(3 * time.Second)
defer timeout.Stop()
m.notifyLock.RLock()
defer m.notifyLock.RUnlock()
m.connChans.Range(func(ch chan *logEntryRange, _ struct{}) bool {
select {
case ch <- &logEntryRange{pos, pos + n}:
return true
case <-timeout.C:
logging.Warn().Msg("mem logger: timeout logging to channel")
return false
}
})
if m.listeners.Size() > 0 {
msg := m.Buffer.Bytes()[pos : pos+n]
m.listeners.Range(func(ch chan []byte, _ struct{}) bool {
select {
case <-timeout.C:
logging.Warn().Msg("mem logger: timeout logging to channel")
return false
case ch <- msg:
return true
}
})
}
return
}
}
func (m *memLogger) writeBuf(b []byte) (pos int, err error) {