package auth import ( "context" "crypto/rand" "encoding/hex" "errors" "fmt" "net/http" "sync" "time" "github.com/golang-jwt/jwt/v5" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/jsonstore" "github.com/yusing/go-proxy/internal/logging" "golang.org/x/oauth2" ) type oauthRefreshToken struct { Username string `json:"username"` RefreshToken string `json:"refresh_token"` Expiry time.Time `json:"expiry"` result *RefreshResult err error mu sync.Mutex } type Session struct { SessionID sessionID `json:"session_id"` Username string `json:"username"` Groups []string `json:"groups"` } type RefreshResult struct { newSession Session jwt string jwtExpiry time.Time } type sessionClaims struct { Session jwt.RegisteredClaims } type sessionID string var oauthRefreshTokens jsonstore.MapStore[*oauthRefreshToken] var ( defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month sessionInvalidateDelay = 3 * time.Second ) var ( errNoRefreshToken = errors.New("no refresh token") ErrRefreshTokenFailure = errors.New("failed to refresh token") ) const sessionTokenIssuer = "GoDoxy" func init() { if IsOIDCEnabled() { oauthRefreshTokens = jsonstore.Store[*oauthRefreshToken]("oauth_refresh_tokens") } } func (token *oauthRefreshToken) expired() bool { return time.Now().After(token.Expiry) } func newSessionID() sessionID { b := make([]byte, 32) _, _ = rand.Read(b) return sessionID(hex.EncodeToString(b)) } func newSession(username string, groups []string) Session { return Session{ SessionID: newSessionID(), Username: username, Groups: groups, } } // getOAuthRefreshToken returns the refresh token for the given session. func getOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { token, ok := oauthRefreshTokens.Load(string(claims.SessionID)) if !ok { return nil, false } if token.expired() { invalidateOAuthRefreshToken(claims.SessionID) return nil, false } if claims.Username != token.Username { return nil, false } return token, true } func storeOAuthRefreshToken(sessionID sessionID, username, token string) { oauthRefreshTokens.Store(string(sessionID), &oauthRefreshToken{ Username: username, RefreshToken: token, Expiry: time.Now().Add(defaultRefreshTokenExpiry), }) logging.Debug().Str("username", username).Msg("stored oauth refresh token") } func invalidateOAuthRefreshToken(sessionID sessionID) { logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token") oauthRefreshTokens.Delete(string(sessionID)) } func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.Request, session Session) { claims := &sessionClaims{ Session: session, RegisteredClaims: jwt.RegisteredClaims{ Issuer: sessionTokenIssuer, ExpiresAt: jwt.NewNumericDate(time.Now().Add(common.APIJWTTokenTTL)), }, } jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) signed, err := jwtToken.SignedString(common.APIJWTSecret) if err != nil { logging.Err(err).Msg("failed to sign session token") return } SetTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL) } func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) { claims = &sessionClaims{} sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (interface{}, error) { if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) } return common.APIJWTSecret, nil }) if err != nil { return nil, false, err } return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil } func (auth *OIDCProvider) TryRefreshToken(ctx context.Context, sessionJWT string) (*RefreshResult, error) { // verify the session cookie claims, valid, err := auth.parseSessionJWT(sessionJWT) if err != nil { return nil, fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrInvalidSessionToken, err) } if !valid { return nil, ErrInvalidSessionToken } // check if refresh is possible refreshToken, ok := getOAuthRefreshToken(&claims.Session) if !ok { return nil, errNoRefreshToken } if !auth.checkAllowed(claims.Username, claims.Groups) { return nil, ErrUserNotAllowed } return auth.doRefreshToken(ctx, refreshToken, &claims.Session) } func (auth *OIDCProvider) doRefreshToken(ctx context.Context, refreshToken *oauthRefreshToken, claims *Session) (*RefreshResult, error) { refreshToken.mu.Lock() defer refreshToken.mu.Unlock() // already refreshed // this must be called after refresh but before invalidate if refreshToken.result != nil || refreshToken.err != nil { return refreshToken.result, refreshToken.err } // this step refreshes the token // see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313 newToken, err := auth.oauthConfig.TokenSource(ctx, &oauth2.Token{ RefreshToken: refreshToken.RefreshToken, }).Token() if err != nil { refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err) return nil, refreshToken.err } idTokenJWT, idToken, err := auth.getIdToken(ctx, newToken) if err != nil { refreshToken.err = fmt.Errorf("session: %s - %w: %w", claims.SessionID, ErrRefreshTokenFailure, err) return nil, refreshToken.err } // in case there're multiple requests for the same session to refresh // invalidate the token after a short delay go func() { <-time.After(sessionInvalidateDelay) invalidateOAuthRefreshToken(claims.SessionID) }() sessionID := newSessionID() logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token") storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken) refreshToken.result = &RefreshResult{ newSession: Session{ SessionID: sessionID, Username: claims.Username, Groups: claims.Groups, }, jwt: idTokenJWT, jwtExpiry: idToken.Expiry, } return refreshToken.result, nil }