From cba7338d8d9f7beac9c5ab68ec3dc31752fcf402 Mon Sep 17 00:00:00 2001 From: yusing Date: Fri, 28 Mar 2025 07:21:20 +0800 Subject: [PATCH] auth: support for end_session_endpoint discovery, remove OIDC_LOGOUT_URL --- internal/api/v1/auth/oidc.go | 87 ++++++++++++++++++++++--------- internal/api/v1/auth/oidc_test.go | 2 +- internal/common/env.go | 2 +- 3 files changed, 63 insertions(+), 28 deletions(-) diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index d9f33b0..125c608 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -4,11 +4,15 @@ import ( "context" "crypto/rand" "encoding/base64" + "encoding/json" "errors" "fmt" + "io" + "mime" "net/http" "net/url" "slices" + "strings" "time" "github.com/coreos/go-oidc/v3/oidc" @@ -19,15 +23,22 @@ import ( "golang.org/x/oauth2" ) -type OIDCProvider struct { - oauthConfig *oauth2.Config - oidcProvider *oidc.Provider - oidcVerifier *oidc.IDTokenVerifier - oidcLogoutURL *url.URL - allowedUsers []string - allowedGroups []string - isMiddleware bool -} +type ( + OIDCProvider struct { + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + oidcEndSessionURL *url.URL + allowedUsers []string + allowedGroups []string + isMiddleware bool + } + + providerJSON struct { + oidc.ProviderConfig + EndSessionURL string `json:"end_session_endpoint"` + } +) const CookieOauthState = "godoxy_oidc_state" @@ -36,25 +47,50 @@ const ( OIDCLogoutPath = "/auth/logout" ) -func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL, logoutURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { +func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { if len(allowedUsers)+len(allowedGroups) == 0 { return nil, errors.New("OIDC users, groups, or both must not be empty") } - var logout *url.URL - var err error - if logoutURL != "" { - logout, err = url.Parse(logoutURL) + wellKnown := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration" + resp, err := gphttp.Get(wellKnown) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("oidc: unable to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oidc: %s: %s", resp.Status, body) + } + + var p providerJSON + err = json.Unmarshal(body, &p) + if err != nil { + mimeType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err == nil && mimeType != "application/json" { + return nil, fmt.Errorf("oidc: unexpected content type: %q from OIDC provider discovery, have you configured the correct issuer URL?", mimeType) + } + return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err) + } + + if p.IssuerURL != issuerURL { + return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuerURL, p.IssuerURL) + } + + var endSessionURL *url.URL + if p.EndSessionURL != "" { + endSessionURL, err = url.Parse(p.EndSessionURL) if err != nil { - return nil, fmt.Errorf("failed to parse logout URL: %w", err) + return nil, fmt.Errorf("oidc: failed to parse end session URL: %w", err) } } - provider, err := oidc.NewProvider(context.Background(), issuerURL) - if err != nil { - return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) - } - + provider := p.NewProvider(context.Background()) return &OIDCProvider{ oauthConfig: &oauth2.Config{ ClientID: clientID, @@ -67,9 +103,9 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL, logoutURL s oidcVerifier: provider.Verifier(&oidc.Config{ ClientID: clientID, }), - oidcLogoutURL: logout, - allowedUsers: allowedUsers, - allowedGroups: allowedGroups, + oidcEndSessionURL: endSessionURL, + allowedUsers: allowedUsers, + allowedGroups: allowedGroups, }, nil } @@ -80,7 +116,6 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) { common.OIDCClientID, common.OIDCClientSecret, common.OIDCRedirectURL, - common.OIDCLogoutURL, common.OIDCAllowedUsers, common.OIDCAllowedGroups, ) @@ -235,7 +270,7 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re } func (auth *OIDCProvider) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { - if auth.oidcLogoutURL == nil { + if auth.oidcEndSessionURL == nil { DefaultLogoutCallbackHandler(auth, w, r) return } @@ -247,7 +282,7 @@ func (auth *OIDCProvider) LogoutCallbackHandler(w http.ResponseWriter, r *http.R } clearTokenCookie(w, r, auth.TokenCookieName()) - logoutURL := *auth.oidcLogoutURL + logoutURL := *auth.oidcEndSessionURL logoutURL.Query().Add("id_token_hint", token.Value) http.Redirect(w, r, logoutURL.String(), http.StatusFound) diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index f249e8f..0ed759f 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -306,7 +306,7 @@ func TestInitOIDC(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.logoutURL, tt.allowedUsers, tt.allowedGroups) + _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups) if (err != nil) != tt.wantErr { t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/internal/common/env.go b/internal/common/env.go index f8f2b3f..52c6d67 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -44,11 +44,11 @@ var ( APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) APIUser = GetEnvString("API_USER", "admin") APIPassword = GetEnvString("API_PASSWORD", "password") + DebugDisableAuth = GetEnvBool("DEBUG_DISABLE_AUTH", false) // OIDC Configuration. OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") - OIDCLogoutURL = GetEnvString("OIDC_LOGOUT_URL", "") OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")