From c5fd21552eb96af4358358f2f992e2993780adc1 Mon Sep 17 00:00:00 2001 From: yusing Date: Mon, 28 Apr 2025 11:19:57 +0800 Subject: [PATCH] fix(oidc): token not being refreshed when receiving simutaneous requests from the same session --- internal/auth/auth.go | 37 ++++++++++---- internal/auth/oauth_refresh.go | 93 ++++++++++++++++++++++++---------- internal/auth/oidc.go | 25 ++++++--- internal/auth/utils.go | 13 +++-- 4 files changed, 119 insertions(+), 49 deletions(-) diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 6e8d231..14c7dbf 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -1,6 +1,7 @@ package auth import ( + "context" "net/http" "github.com/yusing/go-proxy/internal/common" @@ -38,17 +39,35 @@ func IsOIDCEnabled() bool { return common.OIDCIssuerURL != "" } +type nextHandler struct{} + +var nextHandlerContextKey = nextHandler{} + func RequireAuth(next http.HandlerFunc) http.HandlerFunc { - if IsEnabled() { - return func(w http.ResponseWriter, r *http.Request) { - if err := defaultAuth.CheckToken(r); err != nil { - gphttp.ClientError(w, err, http.StatusUnauthorized) - } else { - next(w, r) - } - } + if !IsEnabled() { + return next + } + return func(w http.ResponseWriter, r *http.Request) { + if err := defaultAuth.CheckToken(r); err != nil { + if IsFrontend(r) { + r = r.WithContext(context.WithValue(r.Context(), nextHandlerContextKey, next)) + defaultAuth.LoginHandler(w, r) + } else { + gphttp.ClientError(w, err, http.StatusUnauthorized) + } + return + } + next(w, r) + } +} + +func ProceedNext(w http.ResponseWriter, r *http.Request) { + next, ok := r.Context().Value(nextHandlerContextKey).(http.HandlerFunc) + if ok { + next(w, r) + } else { + w.WriteHeader(http.StatusOK) } - return next } func AuthCheckHandler(w http.ResponseWriter, r *http.Request) { diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go index 800680b..017d046 100644 --- a/internal/auth/oauth_refresh.go +++ b/internal/auth/oauth_refresh.go @@ -1,11 +1,13 @@ package auth import ( + "context" "crypto/rand" - "encoding/base64" + "encoding/hex" "errors" "fmt" "net/http" + "sync" "time" "github.com/golang-jwt/jwt/v5" @@ -19,6 +21,10 @@ type oauthRefreshToken struct { Username string `json:"username"` RefreshToken string `json:"refresh_token"` Expiry time.Time `json:"expiry"` + + result *refreshResult + err error + mu sync.Mutex } type Session struct { @@ -27,6 +33,12 @@ type Session struct { Groups []string `json:"groups"` } +type refreshResult struct { + newSession Session + jwt string + jwtExpiry time.Time +} + type sessionClaims struct { Session jwt.RegisteredClaims @@ -34,11 +46,12 @@ type sessionClaims struct { type sessionID string -var oauthRefreshTokens jsonstore.MapStore[oauthRefreshToken] +var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken] var ( defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month refreshBefore = 30 * time.Second + sessionInvalidateDelay = 3 * time.Second ) var ( @@ -50,7 +63,7 @@ const sessionTokenIssuer = "GoDoxy" func init() { if IsOIDCEnabled() { - oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens") + oauthRefreshTokens = jsonstore.Store[*oauthRefreshToken]("oauth_refresh_tokens") } } @@ -61,7 +74,7 @@ func (token *oauthRefreshToken) expired() bool { func newSessionID() sessionID { b := make([]byte, 32) _, _ = rand.Read(b) - return sessionID(base64.StdEncoding.EncodeToString(b)) + return sessionID(hex.EncodeToString(b)) } func newSession(username string, groups []string) Session { @@ -72,26 +85,28 @@ func newSession(username string, groups []string) Session { } } -// getOnceOAuthRefreshToken returns the refresh token for the given session. +// getOAuthRefreshToken returns the refresh token for the given session. // // The token is removed from the store after retrieval. -func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { +func getOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { token, ok := oauthRefreshTokens.Load(string(claims.SessionID)) if !ok { return nil, false } - invalidateOAuthRefreshToken(claims.SessionID) + if token.expired() { + invalidateOAuthRefreshToken(claims.SessionID) return nil, false } + if claims.Username != token.Username { return nil, false } - return &token, true + return token, true } func storeOAuthRefreshToken(sessionID sessionID, username, token string) { - oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{ + oauthRefreshTokens.Store(string(sessionID), &oauthRefreshToken{ Username: username, RefreshToken: token, Expiry: time.Now().Add(defaultRefreshTokenExpiry), @@ -135,51 +150,75 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil } -func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error { +func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*refreshResult, error) { // verify the session cookie claims, valid, err := auth.parseSessionJWT(sessionJWT) if err != nil { - return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err) + return nil, fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrInvalidSessionToken, err) } if !valid { - return ErrInvalidSessionToken + return nil, ErrInvalidSessionToken } // check if refresh is possible - refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session) + refreshToken, ok := getOAuthRefreshToken(&claims.Session) if !ok { - return errNoRefreshToken + return nil, errNoRefreshToken } if !auth.checkAllowed(claims.Username, claims.Groups) { - return ErrUserNotAllowed + return nil, ErrUserNotAllowed + } + + return auth.doRefreshToken(ctx, refreshToken, &claims.Session) +} + +func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*refreshResult, error) { + refreshToken.mu.Lock() + defer refreshToken.mu.Unlock() + + // already refreshed + // this must be called after refresh but before invalidate + if refreshToken.result != nil || refreshToken.err != nil { + return refreshToken.result, refreshToken.err } // this step refreshes the token // see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313 - newToken, err := auth.oauthConfig.TokenSource(r.Context(), &oauth2.Token{ + newToken, err := auth.oauthConfig.TokenSource(ctx, &oauth2.Token{ RefreshToken: refreshToken.RefreshToken, }).Token() if err != nil { - return fmt.Errorf("%w: %w", ErrRefreshTokenFailure, err) + refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err) + return nil, refreshToken.err } - idTokenJWT, idToken, err := auth.getIdToken(r.Context(), newToken) + idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken) if err != nil { - return err + refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err) + return nil, refreshToken.err } + // in case there're multiple requests for the same session to refresh + // invalidate the token after a short delay + go func() { + <-time.After(sessionInvalidateDelay) + invalidateOAuthRefreshToken(claims.SessionID) + }() + sessionID := newSessionID() logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token") storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken) - // set new idToken and new sessionToken - auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry)) - auth.setSessionTokenCookie(w, r, Session{ - SessionID: sessionID, - Username: claims.Username, - Groups: claims.Groups, - }) - return nil + refreshToken.result = &refreshResult{ + newSession: Session{ + SessionID: sessionID, + Username: claims.Username, + Groups: claims.Groups, + }, + jwt: idTokenJWT, + jwtExpiry: idToken.Expiry, + } + return refreshToken.result, nil } diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 36e4a1b..c1068a5 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -13,6 +13,7 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/utils" @@ -47,7 +48,12 @@ const ( OIDCLogoutPath = "/auth/logout" ) -var errMissingIDToken = errors.New("missing id_token field from oauth token") +var ( + errMissingIDToken = errors.New("missing id_token field from oauth token") + + ErrMissingOAuthToken = gperr.New("missing oauth token") + ErrInvalidOAuthToken = gperr.New("invalid oauth token") +) // generateState generates a random string for OIDC state. const oidcStateLength = 32 @@ -148,12 +154,19 @@ func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { // check for session token sessionToken, err := r.Cookie(CookieOauthSessionToken) - if err == nil { - err = auth.TryRefreshToken(w, r, sessionToken.Value) - if err != nil { - logging.Debug().Err(err).Msg("failed to refresh token") - auth.clearCookie(w, r) + if err == nil { // session token exists + result, err := auth.TryRefreshToken(r.Context(), sessionToken.Value) + // redirect back to where they requested + // when token refresh is ok + if err == nil { + auth.setIDTokenCookie(w, r, result.jwt, time.Until(result.jwtExpiry)) + auth.setSessionTokenCookie(w, r, result.newSession) + ProceedNext(w, r) + return } + // clear cookies then redirect to home + logging.Err(err).Msg("failed to refresh token") + auth.clearCookie(w, r) http.Redirect(w, r, "/", http.StatusFound) return } diff --git a/internal/auth/utils.go b/internal/auth/utils.go index ba0297e..f1c20d6 100644 --- a/internal/auth/utils.go +++ b/internal/auth/utils.go @@ -10,22 +10,21 @@ import ( ) var ( - ErrMissingOAuthToken = gperr.New("missing oauth token") ErrMissingSessionToken = gperr.New("missing session token") - ErrInvalidOAuthToken = gperr.New("invalid oauth token") ErrInvalidSessionToken = gperr.New("invalid session token") ErrUserNotAllowed = gperr.New("user not allowed") ) +func IsFrontend(r *http.Request) bool { + return r.Host == common.APIHTTPAddr +} + func requestHost(r *http.Request) string { // check if it's from backend - switch r.Host { - case common.APIHTTPAddr: - // use XFH + if IsFrontend(r) { return r.Header.Get("X-Forwarded-Host") - default: - return r.Host } + return r.Host } // cookieDomain returns the fully qualified domain name of the request host