mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-22 20:24:03 +02:00
fix(oidc): token not being refreshed when receiving simutaneous requests from the same session
This commit is contained in:
parent
27409abc24
commit
c5fd21552e
4 changed files with 119 additions and 49 deletions
|
@ -1,6 +1,7 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
@ -38,17 +39,35 @@ func IsOIDCEnabled() bool {
|
||||||
return common.OIDCIssuerURL != ""
|
return common.OIDCIssuerURL != ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type nextHandler struct{}
|
||||||
|
|
||||||
|
var nextHandlerContextKey = nextHandler{}
|
||||||
|
|
||||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
if IsEnabled() {
|
if !IsEnabled() {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return next
|
||||||
if err := defaultAuth.CheckToken(r); err != nil {
|
}
|
||||||
gphttp.ClientError(w, err, http.StatusUnauthorized)
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
} else {
|
if err := defaultAuth.CheckToken(r); err != nil {
|
||||||
next(w, r)
|
if IsFrontend(r) {
|
||||||
}
|
r = r.WithContext(context.WithValue(r.Context(), nextHandlerContextKey, next))
|
||||||
}
|
defaultAuth.LoginHandler(w, r)
|
||||||
|
} else {
|
||||||
|
gphttp.ClientError(w, err, http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ProceedNext(w http.ResponseWriter, r *http.Request) {
|
||||||
|
next, ok := r.Context().Value(nextHandlerContextKey).(http.HandlerFunc)
|
||||||
|
if ok {
|
||||||
|
next(w, r)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
return next
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
@ -19,6 +21,10 @@ type oauthRefreshToken struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
RefreshToken string `json:"refresh_token"`
|
RefreshToken string `json:"refresh_token"`
|
||||||
Expiry time.Time `json:"expiry"`
|
Expiry time.Time `json:"expiry"`
|
||||||
|
|
||||||
|
result *refreshResult
|
||||||
|
err error
|
||||||
|
mu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
type Session struct {
|
type Session struct {
|
||||||
|
@ -27,6 +33,12 @@ type Session struct {
|
||||||
Groups []string `json:"groups"`
|
Groups []string `json:"groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type refreshResult struct {
|
||||||
|
newSession Session
|
||||||
|
jwt string
|
||||||
|
jwtExpiry time.Time
|
||||||
|
}
|
||||||
|
|
||||||
type sessionClaims struct {
|
type sessionClaims struct {
|
||||||
Session
|
Session
|
||||||
jwt.RegisteredClaims
|
jwt.RegisteredClaims
|
||||||
|
@ -34,11 +46,12 @@ type sessionClaims struct {
|
||||||
|
|
||||||
type sessionID string
|
type sessionID string
|
||||||
|
|
||||||
var oauthRefreshTokens jsonstore.MapStore[oauthRefreshToken]
|
var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken]
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
||||||
refreshBefore = 30 * time.Second
|
refreshBefore = 30 * time.Second
|
||||||
|
sessionInvalidateDelay = 3 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -50,7 +63,7 @@ const sessionTokenIssuer = "GoDoxy"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if IsOIDCEnabled() {
|
if IsOIDCEnabled() {
|
||||||
oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens")
|
oauthRefreshTokens = jsonstore.Store[*oauthRefreshToken]("oauth_refresh_tokens")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -61,7 +74,7 @@ func (token *oauthRefreshToken) expired() bool {
|
||||||
func newSessionID() sessionID {
|
func newSessionID() sessionID {
|
||||||
b := make([]byte, 32)
|
b := make([]byte, 32)
|
||||||
_, _ = rand.Read(b)
|
_, _ = rand.Read(b)
|
||||||
return sessionID(base64.StdEncoding.EncodeToString(b))
|
return sessionID(hex.EncodeToString(b))
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSession(username string, groups []string) Session {
|
func newSession(username string, groups []string) Session {
|
||||||
|
@ -72,26 +85,28 @@ func newSession(username string, groups []string) Session {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getOnceOAuthRefreshToken returns the refresh token for the given session.
|
// getOAuthRefreshToken returns the refresh token for the given session.
|
||||||
//
|
//
|
||||||
// The token is removed from the store after retrieval.
|
// The token is removed from the store after retrieval.
|
||||||
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
func getOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
||||||
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
|
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
invalidateOAuthRefreshToken(claims.SessionID)
|
|
||||||
if token.expired() {
|
if token.expired() {
|
||||||
|
invalidateOAuthRefreshToken(claims.SessionID)
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if claims.Username != token.Username {
|
if claims.Username != token.Username {
|
||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
return &token, true
|
return token, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
||||||
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{
|
oauthRefreshTokens.Store(string(sessionID), &oauthRefreshToken{
|
||||||
Username: username,
|
Username: username,
|
||||||
RefreshToken: token,
|
RefreshToken: token,
|
||||||
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
||||||
|
@ -135,51 +150,75 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
|
||||||
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error {
|
func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*refreshResult, error) {
|
||||||
// verify the session cookie
|
// verify the session cookie
|
||||||
claims, valid, err := auth.parseSessionJWT(sessionJWT)
|
claims, valid, err := auth.parseSessionJWT(sessionJWT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err)
|
return nil, fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrInvalidSessionToken, err)
|
||||||
}
|
}
|
||||||
if !valid {
|
if !valid {
|
||||||
return ErrInvalidSessionToken
|
return nil, ErrInvalidSessionToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if refresh is possible
|
// check if refresh is possible
|
||||||
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session)
|
refreshToken, ok := getOAuthRefreshToken(&claims.Session)
|
||||||
if !ok {
|
if !ok {
|
||||||
return errNoRefreshToken
|
return nil, errNoRefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||||
return ErrUserNotAllowed
|
return nil, ErrUserNotAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
return auth.doRefreshToken(ctx, refreshToken, &claims.Session)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*refreshResult, error) {
|
||||||
|
refreshToken.mu.Lock()
|
||||||
|
defer refreshToken.mu.Unlock()
|
||||||
|
|
||||||
|
// already refreshed
|
||||||
|
// this must be called after refresh but before invalidate
|
||||||
|
if refreshToken.result != nil || refreshToken.err != nil {
|
||||||
|
return refreshToken.result, refreshToken.err
|
||||||
}
|
}
|
||||||
|
|
||||||
// this step refreshes the token
|
// this step refreshes the token
|
||||||
// see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313
|
// see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313
|
||||||
newToken, err := auth.oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
newToken, err := auth.oauthConfig.TokenSource(ctx, &oauth2.Token{
|
||||||
RefreshToken: refreshToken.RefreshToken,
|
RefreshToken: refreshToken.RefreshToken,
|
||||||
}).Token()
|
}).Token()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrRefreshTokenFailure, err)
|
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
|
||||||
|
return nil, refreshToken.err
|
||||||
}
|
}
|
||||||
|
|
||||||
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), newToken)
|
idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err)
|
||||||
|
return nil, refreshToken.err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// in case there're multiple requests for the same session to refresh
|
||||||
|
// invalidate the token after a short delay
|
||||||
|
go func() {
|
||||||
|
<-time.After(sessionInvalidateDelay)
|
||||||
|
invalidateOAuthRefreshToken(claims.SessionID)
|
||||||
|
}()
|
||||||
|
|
||||||
sessionID := newSessionID()
|
sessionID := newSessionID()
|
||||||
|
|
||||||
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
||||||
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
||||||
|
|
||||||
// set new idToken and new sessionToken
|
refreshToken.result = &refreshResult{
|
||||||
auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry))
|
newSession: Session{
|
||||||
auth.setSessionTokenCookie(w, r, Session{
|
SessionID: sessionID,
|
||||||
SessionID: sessionID,
|
Username: claims.Username,
|
||||||
Username: claims.Username,
|
Groups: claims.Groups,
|
||||||
Groups: claims.Groups,
|
},
|
||||||
})
|
jwt: idTokenJWT,
|
||||||
return nil
|
jwtExpiry: idToken.Expiry,
|
||||||
|
}
|
||||||
|
return refreshToken.result, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,7 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
@ -47,7 +48,12 @@ const (
|
||||||
OIDCLogoutPath = "/auth/logout"
|
OIDCLogoutPath = "/auth/logout"
|
||||||
)
|
)
|
||||||
|
|
||||||
var errMissingIDToken = errors.New("missing id_token field from oauth token")
|
var (
|
||||||
|
errMissingIDToken = errors.New("missing id_token field from oauth token")
|
||||||
|
|
||||||
|
ErrMissingOAuthToken = gperr.New("missing oauth token")
|
||||||
|
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
|
||||||
|
)
|
||||||
|
|
||||||
// generateState generates a random string for OIDC state.
|
// generateState generates a random string for OIDC state.
|
||||||
const oidcStateLength = 32
|
const oidcStateLength = 32
|
||||||
|
@ -148,12 +154,19 @@ func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
// check for session token
|
// check for session token
|
||||||
sessionToken, err := r.Cookie(CookieOauthSessionToken)
|
sessionToken, err := r.Cookie(CookieOauthSessionToken)
|
||||||
if err == nil {
|
if err == nil { // session token exists
|
||||||
err = auth.TryRefreshToken(w, r, sessionToken.Value)
|
result, err := auth.TryRefreshToken(r.Context(), sessionToken.Value)
|
||||||
if err != nil {
|
// redirect back to where they requested
|
||||||
logging.Debug().Err(err).Msg("failed to refresh token")
|
// when token refresh is ok
|
||||||
auth.clearCookie(w, r)
|
if err == nil {
|
||||||
|
auth.setIDTokenCookie(w, r, result.jwt, time.Until(result.jwtExpiry))
|
||||||
|
auth.setSessionTokenCookie(w, r, result.newSession)
|
||||||
|
ProceedNext(w, r)
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
// clear cookies then redirect to home
|
||||||
|
logging.Err(err).Msg("failed to refresh token")
|
||||||
|
auth.clearCookie(w, r)
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,22 +10,21 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingOAuthToken = gperr.New("missing oauth token")
|
|
||||||
ErrMissingSessionToken = gperr.New("missing session token")
|
ErrMissingSessionToken = gperr.New("missing session token")
|
||||||
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
|
|
||||||
ErrInvalidSessionToken = gperr.New("invalid session token")
|
ErrInvalidSessionToken = gperr.New("invalid session token")
|
||||||
ErrUserNotAllowed = gperr.New("user not allowed")
|
ErrUserNotAllowed = gperr.New("user not allowed")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func IsFrontend(r *http.Request) bool {
|
||||||
|
return r.Host == common.APIHTTPAddr
|
||||||
|
}
|
||||||
|
|
||||||
func requestHost(r *http.Request) string {
|
func requestHost(r *http.Request) string {
|
||||||
// check if it's from backend
|
// check if it's from backend
|
||||||
switch r.Host {
|
if IsFrontend(r) {
|
||||||
case common.APIHTTPAddr:
|
|
||||||
// use XFH
|
|
||||||
return r.Header.Get("X-Forwarded-Host")
|
return r.Header.Get("X-Forwarded-Host")
|
||||||
default:
|
|
||||||
return r.Host
|
|
||||||
}
|
}
|
||||||
|
return r.Host
|
||||||
}
|
}
|
||||||
|
|
||||||
// cookieDomain returns the fully qualified domain name of the request host
|
// cookieDomain returns the fully qualified domain name of the request host
|
||||||
|
|
Loading…
Add table
Reference in a new issue