fix(oidc): token not being refreshed when receiving simutaneous requests from the same session

This commit is contained in:
yusing 2025-04-28 11:19:57 +08:00
parent 27409abc24
commit c5fd21552e
4 changed files with 119 additions and 49 deletions

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"net/http" "net/http"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
@ -38,17 +39,35 @@ func IsOIDCEnabled() bool {
return common.OIDCIssuerURL != "" return common.OIDCIssuerURL != ""
} }
type nextHandler struct{}
var nextHandlerContextKey = nextHandler{}
func RequireAuth(next http.HandlerFunc) http.HandlerFunc { func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if IsEnabled() { if !IsEnabled() {
return func(w http.ResponseWriter, r *http.Request) { return next
if err := defaultAuth.CheckToken(r); err != nil { }
gphttp.ClientError(w, err, http.StatusUnauthorized) return func(w http.ResponseWriter, r *http.Request) {
} else { if err := defaultAuth.CheckToken(r); err != nil {
next(w, r) if IsFrontend(r) {
} r = r.WithContext(context.WithValue(r.Context(), nextHandlerContextKey, next))
} defaultAuth.LoginHandler(w, r)
} else {
gphttp.ClientError(w, err, http.StatusUnauthorized)
}
return
}
next(w, r)
}
}
func ProceedNext(w http.ResponseWriter, r *http.Request) {
next, ok := r.Context().Value(nextHandlerContextKey).(http.HandlerFunc)
if ok {
next(w, r)
} else {
w.WriteHeader(http.StatusOK)
} }
return next
} }
func AuthCheckHandler(w http.ResponseWriter, r *http.Request) { func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {

View file

@ -1,11 +1,13 @@
package auth package auth
import ( import (
"context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@ -19,6 +21,10 @@ type oauthRefreshToken struct {
Username string `json:"username"` Username string `json:"username"`
RefreshToken string `json:"refresh_token"` RefreshToken string `json:"refresh_token"`
Expiry time.Time `json:"expiry"` Expiry time.Time `json:"expiry"`
result *refreshResult
err error
mu sync.Mutex
} }
type Session struct { type Session struct {
@ -27,6 +33,12 @@ type Session struct {
Groups []string `json:"groups"` Groups []string `json:"groups"`
} }
type refreshResult struct {
newSession Session
jwt string
jwtExpiry time.Time
}
type sessionClaims struct { type sessionClaims struct {
Session Session
jwt.RegisteredClaims jwt.RegisteredClaims
@ -34,11 +46,12 @@ type sessionClaims struct {
type sessionID string type sessionID string
var oauthRefreshTokens jsonstore.MapStore[oauthRefreshToken] var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken]
var ( var (
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
refreshBefore = 30 * time.Second refreshBefore = 30 * time.Second
sessionInvalidateDelay = 3 * time.Second
) )
var ( var (
@ -50,7 +63,7 @@ const sessionTokenIssuer = "GoDoxy"
func init() { func init() {
if IsOIDCEnabled() { if IsOIDCEnabled() {
oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens") oauthRefreshTokens = jsonstore.Store[*oauthRefreshToken]("oauth_refresh_tokens")
} }
} }
@ -61,7 +74,7 @@ func (token *oauthRefreshToken) expired() bool {
func newSessionID() sessionID { func newSessionID() sessionID {
b := make([]byte, 32) b := make([]byte, 32)
_, _ = rand.Read(b) _, _ = rand.Read(b)
return sessionID(base64.StdEncoding.EncodeToString(b)) return sessionID(hex.EncodeToString(b))
} }
func newSession(username string, groups []string) Session { func newSession(username string, groups []string) Session {
@ -72,26 +85,28 @@ func newSession(username string, groups []string) Session {
} }
} }
// getOnceOAuthRefreshToken returns the refresh token for the given session. // getOAuthRefreshToken returns the refresh token for the given session.
// //
// The token is removed from the store after retrieval. // The token is removed from the store after retrieval.
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { func getOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
token, ok := oauthRefreshTokens.Load(string(claims.SessionID)) token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
if !ok { if !ok {
return nil, false return nil, false
} }
invalidateOAuthRefreshToken(claims.SessionID)
if token.expired() { if token.expired() {
invalidateOAuthRefreshToken(claims.SessionID)
return nil, false return nil, false
} }
if claims.Username != token.Username { if claims.Username != token.Username {
return nil, false return nil, false
} }
return &token, true return token, true
} }
func storeOAuthRefreshToken(sessionID sessionID, username, token string) { func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{ oauthRefreshTokens.Store(string(sessionID), &oauthRefreshToken{
Username: username, Username: username,
RefreshToken: token, RefreshToken: token,
Expiry: time.Now().Add(defaultRefreshTokenExpiry), Expiry: time.Now().Add(defaultRefreshTokenExpiry),
@ -135,51 +150,75 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
} }
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error { func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*refreshResult, error) {
// verify the session cookie // verify the session cookie
claims, valid, err := auth.parseSessionJWT(sessionJWT) claims, valid, err := auth.parseSessionJWT(sessionJWT)
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err) return nil, fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrInvalidSessionToken, err)
} }
if !valid { if !valid {
return ErrInvalidSessionToken return nil, ErrInvalidSessionToken
} }
// check if refresh is possible // check if refresh is possible
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session) refreshToken, ok := getOAuthRefreshToken(&claims.Session)
if !ok { if !ok {
return errNoRefreshToken return nil, errNoRefreshToken
} }
if !auth.checkAllowed(claims.Username, claims.Groups) { if !auth.checkAllowed(claims.Username, claims.Groups) {
return ErrUserNotAllowed return nil, ErrUserNotAllowed
}
return auth.doRefreshToken(ctx, refreshToken, &claims.Session)
}
func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*refreshResult, error) {
refreshToken.mu.Lock()
defer refreshToken.mu.Unlock()
// already refreshed
// this must be called after refresh but before invalidate
if refreshToken.result != nil || refreshToken.err != nil {
return refreshToken.result, refreshToken.err
} }
// this step refreshes the token // this step refreshes the token
// see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313 // see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313
newToken, err := auth.oauthConfig.TokenSource(r.Context(), &oauth2.Token{ newToken, err := auth.oauthConfig.TokenSource(ctx, &oauth2.Token{
RefreshToken: refreshToken.RefreshToken, RefreshToken: refreshToken.RefreshToken,
}).Token() }).Token()
if err != nil { if err != nil {
return fmt.Errorf("%w: %w", ErrRefreshTokenFailure, err) refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
return nil, refreshToken.err
} }
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), newToken) idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken)
if err != nil { if err != nil {
return err refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
return nil, refreshToken.err
} }
// in case there're multiple requests for the same session to refresh
// invalidate the token after a short delay
go func() {
<-time.After(sessionInvalidateDelay)
invalidateOAuthRefreshToken(claims.SessionID)
}()
sessionID := newSessionID() sessionID := newSessionID()
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token") logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken) storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
// set new idToken and new sessionToken refreshToken.result = &refreshResult{
auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry)) newSession: Session{
auth.setSessionTokenCookie(w, r, Session{ SessionID: sessionID,
SessionID: sessionID, Username: claims.Username,
Username: claims.Username, Groups: claims.Groups,
Groups: claims.Groups, },
}) jwt: idTokenJWT,
return nil jwtExpiry: idToken.Expiry,
}
return refreshToken.result, nil
} }

View file

@ -13,6 +13,7 @@ import (
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/net/gphttp" "github.com/yusing/go-proxy/internal/net/gphttp"
"github.com/yusing/go-proxy/internal/utils" "github.com/yusing/go-proxy/internal/utils"
@ -47,7 +48,12 @@ const (
OIDCLogoutPath = "/auth/logout" OIDCLogoutPath = "/auth/logout"
) )
var errMissingIDToken = errors.New("missing id_token field from oauth token") var (
errMissingIDToken = errors.New("missing id_token field from oauth token")
ErrMissingOAuthToken = gperr.New("missing oauth token")
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
)
// generateState generates a random string for OIDC state. // generateState generates a random string for OIDC state.
const oidcStateLength = 32 const oidcStateLength = 32
@ -148,12 +154,19 @@ func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
// check for session token // check for session token
sessionToken, err := r.Cookie(CookieOauthSessionToken) sessionToken, err := r.Cookie(CookieOauthSessionToken)
if err == nil { if err == nil { // session token exists
err = auth.TryRefreshToken(w, r, sessionToken.Value) result, err := auth.TryRefreshToken(r.Context(), sessionToken.Value)
if err != nil { // redirect back to where they requested
logging.Debug().Err(err).Msg("failed to refresh token") // when token refresh is ok
auth.clearCookie(w, r) if err == nil {
auth.setIDTokenCookie(w, r, result.jwt, time.Until(result.jwtExpiry))
auth.setSessionTokenCookie(w, r, result.newSession)
ProceedNext(w, r)
return
} }
// clear cookies then redirect to home
logging.Err(err).Msg("failed to refresh token")
auth.clearCookie(w, r)
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
return return
} }

View file

@ -10,22 +10,21 @@ import (
) )
var ( var (
ErrMissingOAuthToken = gperr.New("missing oauth token")
ErrMissingSessionToken = gperr.New("missing session token") ErrMissingSessionToken = gperr.New("missing session token")
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
ErrInvalidSessionToken = gperr.New("invalid session token") ErrInvalidSessionToken = gperr.New("invalid session token")
ErrUserNotAllowed = gperr.New("user not allowed") ErrUserNotAllowed = gperr.New("user not allowed")
) )
func IsFrontend(r *http.Request) bool {
return r.Host == common.APIHTTPAddr
}
func requestHost(r *http.Request) string { func requestHost(r *http.Request) string {
// check if it's from backend // check if it's from backend
switch r.Host { if IsFrontend(r) {
case common.APIHTTPAddr:
// use XFH
return r.Header.Get("X-Forwarded-Host") return r.Header.Get("X-Forwarded-Host")
default:
return r.Host
} }
return r.Host
} }
// cookieDomain returns the fully qualified domain name of the request host // cookieDomain returns the fully qualified domain name of the request host