diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index a5aaa0a..125c608 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -4,30 +4,41 @@ import ( "context" "crypto/rand" "encoding/base64" + "encoding/json" "errors" "fmt" + "io" + "mime" "net/http" "net/url" "slices" + "strings" "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/net/gphttp" - CE "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils/strutils" "golang.org/x/oauth2" ) -type OIDCProvider struct { - oauthConfig *oauth2.Config - oidcProvider *oidc.Provider - oidcVerifier *oidc.IDTokenVerifier - oidcLogoutURL *url.URL - allowedUsers []string - allowedGroups []string - isMiddleware bool -} +type ( + OIDCProvider struct { + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + oidcEndSessionURL *url.URL + allowedUsers []string + allowedGroups []string + isMiddleware bool + } + + providerJSON struct { + oidc.ProviderConfig + EndSessionURL string `json:"end_session_endpoint"` + } +) const CookieOauthState = "godoxy_oidc_state" @@ -36,25 +47,50 @@ const ( OIDCLogoutPath = "/auth/logout" ) -func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL, logoutURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { +func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { if len(allowedUsers)+len(allowedGroups) == 0 { return nil, errors.New("OIDC users, groups, or both must not be empty") } - var logout *url.URL - var err error - if logoutURL != "" { - logout, err = url.Parse(logoutURL) + wellKnown := strings.TrimSuffix(issuerURL, "/") + "/.well-known/openid-configuration" + resp, err := gphttp.Get(wellKnown) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("oidc: unable to read response body: %v", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("oidc: %s: %s", resp.Status, body) + } + + var p providerJSON + err = json.Unmarshal(body, &p) + if err != nil { + mimeType, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err == nil && mimeType != "application/json" { + return nil, fmt.Errorf("oidc: unexpected content type: %q from OIDC provider discovery, have you configured the correct issuer URL?", mimeType) + } + return nil, fmt.Errorf("oidc: failed to decode provider discovery object: %v", err) + } + + if p.IssuerURL != issuerURL { + return nil, fmt.Errorf("oidc: issuer did not match the issuer returned by provider, expected %q got %q", issuerURL, p.IssuerURL) + } + + var endSessionURL *url.URL + if p.EndSessionURL != "" { + endSessionURL, err = url.Parse(p.EndSessionURL) if err != nil { - return nil, fmt.Errorf("failed to parse logout URL: %w", err) + return nil, fmt.Errorf("oidc: failed to parse end session URL: %w", err) } } - provider, err := oidc.NewProvider(context.Background(), issuerURL) - if err != nil { - return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) - } - + provider := p.NewProvider(context.Background()) return &OIDCProvider{ oauthConfig: &oauth2.Config{ ClientID: clientID, @@ -67,9 +103,9 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL, logoutURL s oidcVerifier: provider.Verifier(&oidc.Config{ ClientID: clientID, }), - oidcLogoutURL: logout, - allowedUsers: allowedUsers, - allowedGroups: allowedGroups, + oidcEndSessionURL: endSessionURL, + allowedUsers: allowedUsers, + allowedGroups: allowedGroups, }, nil } @@ -80,7 +116,6 @@ func NewOIDCProviderFromEnv() (*OIDCProvider, error) { common.OIDCClientID, common.OIDCClientSecret, common.OIDCRedirectURL, - common.OIDCLogoutURL, common.OIDCAllowedUsers, common.OIDCAllowedGroups, ) @@ -130,7 +165,7 @@ func (auth *OIDCProvider) CheckToken(r *http.Request) error { // Logical AND between allowed users and groups. allowedUser := slices.Contains(auth.allowedUsers, claims.Username) - allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0 + allowedGroup := len(utils.Intersect(claims.Groups, auth.allowedGroups)) > 0 if !allowedUser && !allowedGroup { return ErrUserNotAllowed } @@ -235,7 +270,7 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re } func (auth *OIDCProvider) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { - if auth.oidcLogoutURL == nil { + if auth.oidcEndSessionURL == nil { DefaultLogoutCallbackHandler(auth, w, r) return } @@ -247,7 +282,7 @@ func (auth *OIDCProvider) LogoutCallbackHandler(w http.ResponseWriter, r *http.R } clearTokenCookie(w, r, auth.TokenCookieName()) - logoutURL := *auth.oidcLogoutURL + logoutURL := *auth.oidcEndSessionURL logoutURL.Query().Add("id_token_hint", token.Value) http.Redirect(w, r, logoutURL.String(), http.StatusFound) diff --git a/internal/common/env.go b/internal/common/env.go index 7a10941..3cb9937 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -46,7 +46,6 @@ var ( // OIDC Configuration. OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") - OIDCLogoutURL = GetEnvString("OIDC_LOGOUT_URL", "") OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 417469f..49191e9 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -80,11 +80,12 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce return false } + if r.URL.Path == auth.OIDCLogoutPath { + amw.auth.LogoutCallbackHandler(w, r) + } if err := amw.auth.CheckToken(r); err != nil { if errors.Is(err, auth.ErrMissingToken) { amw.authMux.ServeHTTP(w, r) - } else if r.URL.Path == auth.OIDCLogoutPath { - amw.auth.LogoutCallbackHandler(w, r) } else { auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) }