diff --git a/cmd/main.go b/cmd/main.go index d021457..eb46cc1 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -127,11 +127,18 @@ func main() { return } + cfg.Start(&config.StartServersOptions{ + Proxy: true, + Metrics: true, + }) if err := auth.Initialize(); err != nil { logging.Fatal().Err(err).Msg("failed to initialize authentication") } + // API Handler needs to start after auth is initialized. + cfg.StartServers(&config.StartServersOptions{ + API: true, + }) - cfg.Start() config.WatchChanges() sig := make(chan os.Signal, 1) diff --git a/internal/config/config.go b/internal/config/config.go index 8e96fe6..aa8e07e 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -150,10 +150,10 @@ func (cfg *Config) Context() context.Context { return cfg.task.Context() } -func (cfg *Config) Start() { +func (cfg *Config) Start(opts ...*StartServersOptions) { cfg.StartAutoCert() cfg.StartProxyProviders() - cfg.StartServers() + cfg.StartServers(opts...) } func (cfg *Config) StartAutoCert() { @@ -187,7 +187,7 @@ type StartServersOptions struct { func (cfg *Config) StartServers(opts ...*StartServersOptions) { if len(opts) == 0 { - opts = append(opts, &StartServersOptions{Proxy: true, API: true, Metrics: true}) + opts = append(opts, &StartServersOptions{}) } opt := opts[0] if opt.Proxy { diff --git a/internal/net/http/middleware/oidc.go b/internal/net/http/middleware/oidc.go index 3b0adc9..3af1ca3 100644 --- a/internal/net/http/middleware/oidc.go +++ b/internal/net/http/middleware/oidc.go @@ -2,6 +2,8 @@ package middleware import ( "net/http" + "sync" + "sync/atomic" "github.com/yusing/go-proxy/internal/api/v1/auth" E "github.com/yusing/go-proxy/internal/error" @@ -13,6 +15,9 @@ type oidcMiddleware struct { auth auth.Provider authMux *http.ServeMux + + isInitialized int32 + initMu sync.Mutex } var OIDC = NewMiddleware[oidcMiddleware]() @@ -21,6 +26,29 @@ func (amw *oidcMiddleware) finalize() error { if !auth.IsOIDCEnabled() { return E.New("OIDC not enabled but OIDC middleware is used") } + return nil +} + +func (amw *oidcMiddleware) init() error { + if atomic.LoadInt32(&amw.isInitialized) == 1 { + return nil + } + + return amw.initSlow() +} + +func (amw *oidcMiddleware) initSlow() error { + amw.initMu.Lock() + if amw.isInitialized == 1 { + amw.initMu.Unlock() + return nil + } + + defer func() { + amw.isInitialized = 1 + amw.initMu.Unlock() + }() + authProvider, err := auth.NewOIDCProviderFromEnv() if err != nil { return err @@ -45,6 +73,12 @@ func (amw *oidcMiddleware) finalize() error { } func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + if err := amw.init(); err != nil { + // no need to log here, main OIDC may already failed and logged + http.Error(w, err.Error(), http.StatusInternalServerError) + return false + } + if err := amw.auth.CheckToken(r); err != nil { amw.authMux.ServeHTTP(w, r) return false