From b815c6fd6938f1f022ade2f6fc5f3c5be0da840c Mon Sep 17 00:00:00 2001 From: yusing Date: Wed, 23 Apr 2025 17:50:22 +0800 Subject: [PATCH] feat(oidc): support token refreshing via offline_access scope - refactored code - moved api/v1/auth to auth/ - security enhancement - env example update - default jwt ttl changed to 24 hours --- .env.example | 17 +- cmd/main.go | 4 +- internal/api/handler.go | 2 +- internal/api/v1/auth/oidc.go | 272 ----------------- internal/{api/v1 => }/auth/auth.go | 0 internal/{api/v1 => }/auth/block_page.go | 0 internal/{api/v1 => }/auth/block_page.html | 0 internal/auth/oauth_refresh.go | 181 ++++++++++++ internal/auth/oidc.go | 310 ++++++++++++++++++++ internal/{api/v1 => }/auth/oidc_test.go | 4 +- internal/{api/v1 => }/auth/provider.go | 5 +- internal/{api/v1 => }/auth/userpass.go | 0 internal/{api/v1 => }/auth/userpass_test.go | 0 internal/{api/v1 => }/auth/utils.go | 11 - internal/common/constants.go | 2 + internal/common/crypto.go | 4 +- internal/common/env.go | 2 +- internal/jsonstore/internal.go | 63 ++++ internal/jsonstore/jsonstore.go | 47 +++ internal/jsonstore/jsonstore_test.go | 32 ++ internal/net/gphttp/middleware/oidc.go | 22 +- 21 files changed, 668 insertions(+), 310 deletions(-) delete mode 100644 internal/api/v1/auth/oidc.go rename internal/{api/v1 => }/auth/auth.go (100%) rename internal/{api/v1 => }/auth/block_page.go (100%) rename internal/{api/v1 => }/auth/block_page.html (100%) create mode 100644 internal/auth/oauth_refresh.go create mode 100644 internal/auth/oidc.go rename internal/{api/v1 => }/auth/oidc_test.go (99%) rename internal/{api/v1 => }/auth/provider.go (61%) rename internal/{api/v1 => }/auth/userpass.go (100%) rename internal/{api/v1 => }/auth/userpass_test.go (100%) rename internal/{api/v1 => }/auth/utils.go (85%) create mode 100644 internal/jsonstore/internal.go create mode 100644 internal/jsonstore/jsonstore.go create mode 100644 internal/jsonstore/jsonstore_test.go diff --git a/.env.example b/.env.example index 50d8ae8..770dc99 100644 --- a/.env.example +++ b/.env.example @@ -1,23 +1,26 @@ # set timezone to get correct log timestamp TZ=ETC/UTC +# API JWT Configuration (common) +# generate secret with `openssl rand -base64 32` +GODOXY_API_JWT_SECRET= +# the JWT token time-to-live +# leave empty to use default (24 hours) +# format: https://pkg.go.dev/time#Duration +GODOXY_API_JWT_TOKEN_TTL= + # API/WebUI user password login credentials (optional) # These fields are not required for OIDC authentication GODOXY_API_USER=admin GODOXY_API_PASSWORD=password -# generate secret with `openssl rand -base64 32` -GODOXY_API_JWT_SECRET= -# the JWT token time-to-live -GODOXY_API_JWT_TOKEN_TTL=1h # OIDC Configuration (optional) # Uncomment and configure these values to enable OIDC authentication. +# For `GODOXY_OIDC_SCOPES` you may also include `offline_access` if your Idp supports it (e.g. Authentik) +# # GODOXY_OIDC_ISSUER_URL=https://accounts.google.com # GODOXY_OIDC_CLIENT_ID=your-client-id # GODOXY_OIDC_CLIENT_SECRET=your-client-secret -# Keep /api/auth/callback as the redirect URL, change the domain to match your setup. -# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback -# Comma-separated list of scopes # GODOXY_OIDC_SCOPES=openid, profile, email # # User definitions: Uncomment and configure these values to restrict access to specific users or groups. diff --git a/cmd/main.go b/cmd/main.go index d93179a..7768711 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -7,13 +7,14 @@ import ( "sync" "github.com/yusing/go-proxy/internal" - "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/favicon" "github.com/yusing/go-proxy/internal/api/v1/query" + "github.com/yusing/go-proxy/internal/auth" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/homepage" + "github.com/yusing/go-proxy/internal/jsonstore" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/metrics/systeminfo" @@ -79,6 +80,7 @@ func main() { logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion()) logging.Trace().Msg("trace enabled") parallel( + jsonstore.Initialize, internal.InitIconListCache, homepage.InitOverridesConfig, favicon.InitIconCache, diff --git a/internal/api/handler.go b/internal/api/handler.go index 178fb3d..70f12ea 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -6,10 +6,10 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" v1 "github.com/yusing/go-proxy/internal/api/v1" - "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/certapi" "github.com/yusing/go-proxy/internal/api/v1/dockerapi" "github.com/yusing/go-proxy/internal/api/v1/favicon" + "github.com/yusing/go-proxy/internal/auth" "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" "github.com/yusing/go-proxy/internal/logging" diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go deleted file mode 100644 index 7ef9b74..0000000 --- a/internal/api/v1/auth/oidc.go +++ /dev/null @@ -1,272 +0,0 @@ -package auth - -import ( - "context" - "errors" - "fmt" - "net/http" - "net/url" - "slices" - "time" - - "github.com/coreos/go-oidc/v3/oidc" - "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/net/gphttp" - "github.com/yusing/go-proxy/internal/utils" - "github.com/yusing/go-proxy/internal/utils/strutils" - "golang.org/x/oauth2" -) - -type ( - OIDCProvider struct { - oauthConfig *oauth2.Config - oidcProvider *oidc.Provider - oidcVerifier *oidc.IDTokenVerifier - oidcEndSessionURL *url.URL - allowedUsers []string - allowedGroups []string - } - - providerJSON struct { - oidc.ProviderConfig - EndSessionURL string `json:"end_session_endpoint"` - } -) - -const CookieOauthState = "godoxy_oidc_state" - -const ( - OIDCAuthCallbackPath = "/auth/callback" - OIDCPostAuthPath = "/auth/postauth" - OIDCLogoutPath = "/auth/logout" -) - -var ( - ErrMissingState = errors.New("missing state cookie") - ErrInvalidState = errors.New("invalid oauth state") -) - -func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { - if len(allowedUsers)+len(allowedGroups) == 0 { - return nil, errors.New("OIDC users, groups, or both must not be empty") - } - provider, err := oidc.NewProvider(context.Background(), issuerURL) - if err != nil { - return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) - } - - endSessionURL, err := url.Parse(provider.EndSessionEndpoint()) - if err != nil && provider.EndSessionEndpoint() != "" { - // non critical, just warn - logging.Warn(). - Str("issuer", issuerURL). - Err(err). - Msg("failed to parse end session URL") - } - - return &OIDCProvider{ - oauthConfig: &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: "", - Endpoint: provider.Endpoint(), - Scopes: strutils.CommaSeperatedList(common.OIDCScopes), - }, - oidcProvider: provider, - oidcVerifier: provider.Verifier(&oidc.Config{ - ClientID: clientID, - }), - oidcEndSessionURL: endSessionURL, - allowedUsers: allowedUsers, - allowedGroups: allowedGroups, - }, nil -} - -// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables. -func NewOIDCProviderFromEnv() (*OIDCProvider, error) { - return NewOIDCProvider( - common.OIDCIssuerURL, - common.OIDCClientID, - common.OIDCClientSecret, - common.OIDCAllowedUsers, - common.OIDCAllowedGroups, - ) -} - -func (auth *OIDCProvider) TokenCookieName() string { - return "godoxy_oidc_token" -} - -func (auth *OIDCProvider) SetAllowedUsers(users []string) { - auth.allowedUsers = users -} - -func (auth *OIDCProvider) SetAllowedGroups(groups []string) { - auth.allowedGroups = groups -} - -func (auth *OIDCProvider) getVerifyStateCookie(r *http.Request) (string, error) { - state, err := r.Cookie(CookieOauthState) - if err != nil { - return "", ErrMissingState - } - if r.URL.Query().Get("state") != state.Value { - return "", ErrInvalidState - } - return state.Value, nil -} - -func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption { - return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath) -} - -func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { - switch r.Method { - case http.MethodHead: - w.WriteHeader(http.StatusOK) - return - case http.MethodGet: - break - default: - gphttp.Forbidden(w, "method not allowed") - return - } - - switch r.URL.Path { - case OIDCAuthCallbackPath: - state := generateState() - http.SetCookie(w, &http.Cookie{ - Name: CookieOauthState, - Value: state, - MaxAge: 300, - HttpOnly: true, - SameSite: http.SameSiteLaxMode, - Secure: common.APIJWTSecure, - Path: "/", - }) - // redirect user to Idp - http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusTemporaryRedirect) - case OIDCPostAuthPath: - auth.PostAuthCallbackHandler(w, r) - default: - auth.LogoutHandler(w, r) - } -} - -func (auth *OIDCProvider) CheckToken(r *http.Request) error { - token, err := r.Cookie(auth.TokenCookieName()) - if err != nil { - return ErrMissingToken - } - - // checks for Expiry, Audience == ClientID, Issuer, etc. - idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value) - if err != nil { - return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err) - } - - if len(idToken.Audience) == 0 { - return ErrInvalidToken - } - - var claims struct { - Email string `json:"email"` - Username string `json:"preferred_username"` - Groups []string `json:"groups"` - } - if err := idToken.Claims(&claims); err != nil { - return fmt.Errorf("failed to parse claims: %w", err) - } - - // Logical AND between allowed users and groups. - allowedUser := slices.Contains(auth.allowedUsers, claims.Username) - allowedGroup := len(utils.Intersect(claims.Groups, auth.allowedGroups)) > 0 - if !allowedUser && !allowedGroup { - return ErrUserNotAllowed - } - return nil -} - -func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) -} - -func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) { - // For testing purposes, skip provider verification - if common.IsTest { - auth.handleTestCallback(w, r) - return - } - - _, err := auth.getVerifyStateCookie(r) - if err != nil { - gphttp.BadRequest(w, err.Error()) - return - } - - code := r.URL.Query().Get("code") - oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r)) - if err != nil { - gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err)) - return - } - - rawIDToken, ok := oauth2Token.Extra("id_token").(string) - if !ok { - gphttp.BadRequest(w, "missing id_token") - return - } - - idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken) - if err != nil { - gphttp.ServerError(w, r, fmt.Errorf("failed to verify ID token: %w", err)) - return - } - - setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry)) - - // Redirect to home page - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) -} - -func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) { - if auth.oidcEndSessionURL == nil { - clearTokenCookie(w, r, auth.TokenCookieName()) - http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect) - return - } - - token, err := r.Cookie(auth.TokenCookieName()) - if err == nil { - query := auth.oidcEndSessionURL.Query() - query.Add("id_token_hint", token.Value) - - logoutURL := *auth.oidcEndSessionURL - logoutURL.RawQuery = query.Encode() - - clearTokenCookie(w, r, auth.TokenCookieName()) - http.Redirect(w, r, logoutURL.String(), http.StatusFound) - } - - http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect) -} - -// handleTestCallback handles OIDC callback in test environment. -func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) { - state, err := r.Cookie(CookieOauthState) - if err != nil { - gphttp.BadRequest(w, "missing state cookie") - return - } - - if r.URL.Query().Get("state") != state.Value { - gphttp.BadRequest(w, "invalid oauth state") - return - } - - // Create test JWT token - setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour) - - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) -} diff --git a/internal/api/v1/auth/auth.go b/internal/auth/auth.go similarity index 100% rename from internal/api/v1/auth/auth.go rename to internal/auth/auth.go diff --git a/internal/api/v1/auth/block_page.go b/internal/auth/block_page.go similarity index 100% rename from internal/api/v1/auth/block_page.go rename to internal/auth/block_page.go diff --git a/internal/api/v1/auth/block_page.html b/internal/auth/block_page.html similarity index 100% rename from internal/api/v1/auth/block_page.html rename to internal/auth/block_page.html diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go new file mode 100644 index 0000000..6fce981 --- /dev/null +++ b/internal/auth/oauth_refresh.go @@ -0,0 +1,181 @@ +package auth + +import ( + "crypto/rand" + "encoding/base64" + "fmt" + "net/http" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/jsonstore" + "github.com/yusing/go-proxy/internal/logging" + "golang.org/x/oauth2" +) + +type oauthRefreshToken struct { + Username string `json:"username"` + RefreshToken string `json:"refresh_token"` + Expiry time.Time `json:"expiry"` +} + +type Session struct { + SessionID sessionID `json:"session_id"` + Username string `json:"username"` + Groups []string `json:"groups"` +} + +type sessionClaims struct { + Session + jwt.RegisteredClaims +} + +type sessionID string + +var oauthRefreshTokens jsonstore.JSONStore[oauthRefreshToken] + +var ( + defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month + refreshBefore = 30 * time.Second +) + +const sessionTokenIssuer = "GoDoxy" + +func init() { + if IsOIDCEnabled() { + oauthRefreshTokens = jsonstore.NewStore[oauthRefreshToken]("oauth_refresh_tokens") + } +} + +func (token *oauthRefreshToken) expired() bool { + return time.Now().After(token.Expiry) +} + +func newSessionID() sessionID { + b := make([]byte, 32) + _, _ = rand.Read(b) + return sessionID(base64.StdEncoding.EncodeToString(b)) +} + +func newSession(username string, groups []string) Session { + return Session{ + SessionID: newSessionID(), + Username: username, + Groups: groups, + } +} + +func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { + token, ok := oauthRefreshTokens.Load(string(claims.SessionID)) + if !ok { + return nil, false + } + invalidateOAuthRefreshToken(claims.SessionID) + if token.expired() { + return nil, false + } + if claims.Username != token.Username { + return nil, false + } + return &token, true +} + +func storeOAuthRefreshToken(sessionID sessionID, username, token string) { + logging.Debug().Str("username", username).Msg("setting oauth refresh token") + oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{ + Username: username, + RefreshToken: token, + Expiry: time.Now().Add(defaultRefreshTokenExpiry), + }) +} + +func invalidateOAuthRefreshToken(sessionID sessionID) { + oauthRefreshTokens.Delete(string(sessionID)) +} + +func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.Request, session Session) { + claims := &sessionClaims{ + Session: session, + RegisteredClaims: jwt.RegisteredClaims{ + Issuer: sessionTokenIssuer, + ExpiresAt: jwt.NewNumericDate(time.Now().Add(common.APIJWTTokenTTL)), + }, + } + jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims) + signed, err := jwtToken.SignedString(common.APIJWTSecret) + if err != nil { + logging.Err(err).Msg("failed to sign session token") + return + } + setTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL) +} + +func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) { + claims = &sessionClaims{} + sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return common.APIJWTSecret, nil + }) + if err != nil { + return nil, false, err + } + return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil +} + +func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request) error { + // check for session token + sessionCookie, err := r.Cookie(CookieOauthSessionToken) + if err != nil { + return ErrMissingToken + } + + // verify the session cookie + claims, valid, err := auth.parseSessionJWT(sessionCookie.Value) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidToken, err) + } + if !valid { + return ErrInvalidToken + } + + // check if refresh is possible + refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session) + if !ok { + return ErrMissingToken + } + + if !auth.checkAllowed(claims.Username, claims.Groups) { + return ErrUserNotAllowed + } + + // this step refreshes the token + // see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313 + newToken, err := auth.oauthConfig.TokenSource(r.Context(), &oauth2.Token{ + RefreshToken: refreshToken.RefreshToken, + }).Token() + if err != nil { + return fmt.Errorf("%w: %w", ErrRefreshTokenFailure, err) + } + + idTokenJWT, idToken, err := auth.getIdToken(r.Context(), newToken) + if err != nil { + return err + } + + sessionID := newSessionID() + + logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token") + storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken) + + // set new idToken and new sessionToken + auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry)) + auth.setSessionTokenCookie(w, r, Session{ + SessionID: sessionID, + Username: claims.Username, + Groups: claims.Groups, + }) + return nil +} diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go new file mode 100644 index 0000000..0233c73 --- /dev/null +++ b/internal/auth/oidc.go @@ -0,0 +1,310 @@ +package auth + +import ( + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "net/http" + "net/url" + "slices" + "time" + + "github.com/coreos/go-oidc/v3/oidc" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/net/gphttp" + "github.com/yusing/go-proxy/internal/utils" + "github.com/yusing/go-proxy/internal/utils/strutils" + "golang.org/x/oauth2" +) + +type ( + OIDCProvider struct { + oauthConfig *oauth2.Config + oidcProvider *oidc.Provider + oidcVerifier *oidc.IDTokenVerifier + endSessionURL *url.URL + allowedUsers []string + allowedGroups []string + } + + IDTokenClaims struct { + Username string `json:"preferred_username"` + Groups []string `json:"groups"` + } +) + +const ( + CookieOauthState = "godoxy_oidc_state" + CookieOauthSessionID = "godoxy_session_id" + CookieOauthToken = "godoxy_token" + CookieOauthSessionToken = "godoxy_session_token" +) + +const ( + OIDCAuthCallbackPath = "/auth/callback" + OIDCPostAuthPath = "/auth/postauth" + OIDCLogoutPath = "/auth/logout" +) + +var ( + ErrMissingIDToken = errors.New("missing id_token") + ErrRefreshTokenFailure = errors.New("failed to refresh token") +) + +// generateState generates a random string for OIDC state. +const oidcStateLength = 32 + +func generateState() string { + b := make([]byte, oidcStateLength) + _, _ = rand.Read(b) + return base64.URLEncoding.EncodeToString(b)[:oidcStateLength] +} + +func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) { + if len(allowedUsers)+len(allowedGroups) == 0 { + return nil, errors.New("OIDC users, groups, or both must not be empty") + } + provider, err := oidc.NewProvider(context.Background(), issuerURL) + if err != nil { + return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) + } + + endSessionURL, err := url.Parse(provider.EndSessionEndpoint()) + if err != nil && provider.EndSessionEndpoint() != "" { + // non critical, just warn + logging.Warn(). + Str("issuer", issuerURL). + Err(err). + Msg("failed to parse end session URL") + } + + return &OIDCProvider{ + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: "", + Endpoint: provider.Endpoint(), + Scopes: strutils.CommaSeperatedList(common.OIDCScopes), + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ + ClientID: clientID, + }), + endSessionURL: endSessionURL, + allowedUsers: allowedUsers, + allowedGroups: allowedGroups, + }, nil +} + +// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables. +func NewOIDCProviderFromEnv() (*OIDCProvider, error) { + return NewOIDCProvider( + common.OIDCIssuerURL, + common.OIDCClientID, + common.OIDCClientSecret, + common.OIDCAllowedUsers, + common.OIDCAllowedGroups, + ) +} + +func (auth *OIDCProvider) SetAllowedUsers(users []string) { + auth.allowedUsers = users +} + +func (auth *OIDCProvider) SetAllowedGroups(groups []string) { + auth.allowedGroups = groups +} + +func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption { + return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath) +} + +func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) { + idTokenJWT, ok := oauthToken.Extra("id_token").(string) + if !ok { + return "", nil, ErrMissingIDToken + } + idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT) + if err != nil { + return "", nil, fmt.Errorf("failed to verify ID token: %w", err) + } + return idTokenJWT, idToken, nil +} + +func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { + // check for session token + _, err := r.Cookie(CookieOauthSessionToken) + if err == nil { + err := auth.TryRefreshToken(w, r) + if err != nil { + logging.Debug().Err(err).Msg("failed to refresh token") + auth.LogoutHandler(w, r) + } else { + http.Redirect(w, r, "/", http.StatusFound) + } + return + } + + switch r.URL.Path { + case OIDCAuthCallbackPath: + state := generateState() + setTokenCookie(w, r, CookieOauthState, state, 300*time.Second) + // redirect user to Idp + http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound) + case OIDCPostAuthPath: + auth.PostAuthCallbackHandler(w, r) + case OIDCLogoutPath: + auth.LogoutHandler(w, r) + default: + http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusFound) + } +} + +func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) { + var claim IDTokenClaims + if err := idToken.Claims(&claim); err != nil { + return nil, fmt.Errorf("failed to parse claims: %w", err) + } + if claim.Username == "" { + return nil, fmt.Errorf("missing username in ID token") + } + return &claim, nil +} + +func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool { + userAllowed := slices.Contains(auth.allowedUsers, user) + if !userAllowed { + return false + } + if len(auth.allowedGroups) == 0 { + return true + } + return len(utils.Intersect(groups, auth.allowedGroups)) > 0 +} + +func (auth *OIDCProvider) CheckToken(r *http.Request) error { + tokenCookie, err := r.Cookie(CookieOauthToken) + if err != nil { + return ErrMissingToken + } + + idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidToken, err) + } + + claims, err := parseClaims(idToken) + if err != nil { + return fmt.Errorf("%w: %w", ErrInvalidToken, err) + } + + if !auth.checkAllowed(claims.Username, claims.Groups) { + return ErrUserNotAllowed + } + return nil +} + +func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) { + // For testing purposes, skip provider verification + if common.IsTest { + auth.handleTestCallback(w, r) + return + } + + // verify state + state, err := r.Cookie(CookieOauthState) + if err != nil { + gphttp.BadRequest(w, "missing state cookie") + return + } + if r.URL.Query().Get("state") != state.Value { + gphttp.BadRequest(w, "invalid oauth state") + return + } + + code := r.URL.Query().Get("code") + oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r)) + if err != nil { + gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err)) + return + } + + idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token) + if err != nil { + gphttp.ServerError(w, r, err) + return + } + + if oauth2Token.RefreshToken != "" { + claims, err := parseClaims(idToken) + if err != nil { + gphttp.ServerError(w, r, err) + return + } + session := newSession(claims.Username, claims.Groups) + storeOAuthRefreshToken(session.SessionID, claims.Username, oauth2Token.RefreshToken) + auth.setSessionTokenCookie(w, r, session) + } + auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry)) + + // Redirect to home page + http.Redirect(w, r, "/", http.StatusFound) +} + +func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) { + oauthToken, _ := r.Cookie(CookieOauthToken) + sessionToken, _ := r.Cookie(CookieOauthSessionToken) + auth.clearCookie(w, r) + + if sessionToken != nil { + claims, _, err := auth.parseSessionJWT(sessionToken.Value) + if err == nil { + invalidateOAuthRefreshToken(claims.SessionID) + } + } + + url := "/" + if auth.endSessionURL != nil && oauthToken != nil { + query := auth.endSessionURL.Query() + query.Set("id_token_hint", oauthToken.Value) + + clone := *auth.endSessionURL + clone.RawQuery = query.Encode() + url = clone.String() + } else if auth.endSessionURL != nil { + url = auth.endSessionURL.String() + } + + http.Redirect(w, r, url, http.StatusFound) +} + +func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) { + setTokenCookie(w, r, CookieOauthToken, jwt, ttl) +} + +func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) { + clearTokenCookie(w, r, CookieOauthToken) + clearTokenCookie(w, r, CookieOauthSessionToken) +} + +// handleTestCallback handles OIDC callback in test environment. +func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) { + state, err := r.Cookie(CookieOauthState) + if err != nil { + gphttp.BadRequest(w, "missing state cookie") + return + } + + if r.URL.Query().Get("state") != state.Value { + gphttp.BadRequest(w, "invalid oauth state") + return + } + + // Create test JWT token + setTokenCookie(w, r, CookieOauthToken, "test", time.Hour) + + http.Redirect(w, r, "/", http.StatusFound) +} diff --git a/internal/api/v1/auth/oidc_test.go b/internal/auth/oidc_test.go similarity index 99% rename from internal/api/v1/auth/oidc_test.go rename to internal/auth/oidc_test.go index 33a7293..27fe17a 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/auth/oidc_test.go @@ -227,7 +227,7 @@ func TestOIDCCallbackHandler(t *testing.T) { if tt.wantStatus == http.StatusTemporaryRedirect { setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) - ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName()) + ExpectEqual(t, setCookie.Name, CookieOauthToken) ExpectTrue(t, setCookie.Value != "") ExpectEqual(t, setCookie.Path, "/") ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) @@ -434,7 +434,7 @@ func TestCheckToken(t *testing.T) { // Craft a test HTTP request that includes the token as a cookie. req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{ - Name: auth.TokenCookieName(), + Name: CookieOauthToken, Value: signedToken, }) diff --git a/internal/api/v1/auth/provider.go b/internal/auth/provider.go similarity index 61% rename from internal/api/v1/auth/provider.go rename to internal/auth/provider.go index d377027..f56b927 100644 --- a/internal/api/v1/auth/provider.go +++ b/internal/auth/provider.go @@ -1,10 +1,7 @@ package auth -import ( - "net/http" -) +import "net/http" type Provider interface { - TokenCookieName() string CheckToken(r *http.Request) error } diff --git a/internal/api/v1/auth/userpass.go b/internal/auth/userpass.go similarity index 100% rename from internal/api/v1/auth/userpass.go rename to internal/auth/userpass.go diff --git a/internal/api/v1/auth/userpass_test.go b/internal/auth/userpass_test.go similarity index 100% rename from internal/api/v1/auth/userpass_test.go rename to internal/auth/userpass_test.go diff --git a/internal/api/v1/auth/utils.go b/internal/auth/utils.go similarity index 85% rename from internal/api/v1/auth/utils.go rename to internal/auth/utils.go index 2ddd3d1..edefe33 100644 --- a/internal/api/v1/auth/utils.go +++ b/internal/auth/utils.go @@ -1,8 +1,6 @@ package auth import ( - "crypto/rand" - "encoding/base64" "net" "net/http" "time" @@ -74,12 +72,3 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) { Path: "/", }) } - -// generateState generates a random string for OIDC state. -const oidcStateLength = 32 - -func generateState() string { - b := make([]byte, oidcStateLength) - _, _ = rand.Read(b) - return base64.URLEncoding.EncodeToString(b)[:oidcStateLength] -} diff --git a/internal/common/constants.go b/internal/common/constants.go index 17e6190..0cc7a4b 100644 --- a/internal/common/constants.go +++ b/internal/common/constants.go @@ -23,6 +23,8 @@ const ( ComposeFileName = "compose.yml" ComposeExampleFileName = "compose.example.yml" + DataDir = "data" + ErrorPagesBasePath = "error_pages" AgentCertsBasePath = "certs" diff --git a/internal/common/crypto.go b/internal/common/crypto.go index 3dcc204..12707ac 100644 --- a/internal/common/crypto.go +++ b/internal/common/crypto.go @@ -13,7 +13,7 @@ func decodeJWTKey(key string) []byte { } bytes, err := base64.StdEncoding.DecodeString(key) if err != nil { - log.Panic().Err(err).Msg("failed to decode jwt key") + log.Fatal().Str("key", key).Err(err).Msg("failed to decode secret") } return bytes } @@ -22,7 +22,7 @@ func RandomJWTKey() []byte { key := make([]byte, 32) _, err := rand.Read(key) if err != nil { - log.Panic().Err(err).Msg("failed to generate random jwt key") + log.Fatal().Err(err).Msg("failed to generate random jwt key") } return key } diff --git a/internal/common/env.go b/internal/common/env.go index 12bdcb7..5fb6a42 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -38,7 +38,7 @@ var ( APIJWTSecure = GetEnvBool("API_JWT_SECURE", true) APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) - APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) + APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", 24*time.Hour) APIUser = GetEnvString("API_USER", "admin") APIPassword = GetEnvString("API_PASSWORD", "password") diff --git a/internal/jsonstore/internal.go b/internal/jsonstore/internal.go new file mode 100644 index 0000000..e2ba875 --- /dev/null +++ b/internal/jsonstore/internal.go @@ -0,0 +1,63 @@ +package jsonstore + +import ( + "encoding/json" + "path/filepath" + "sync" + + "github.com/puzpuzpuz/xsync/v3" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" +) + +type jsonStoreInternal struct{ *xsync.MapOf[string, any] } +type namespace string + +var stores = make(map[namespace]jsonStoreInternal) +var storesMu sync.Mutex +var storesPath = filepath.Join(common.DataDir, "data.json") + +func Initialize() { + if err := load(); err != nil { + logging.Error().Err(err).Msg("failed to load stores") + } + + task.OnProgramExit("save_stores", func() { + if err := save(); err != nil { + logging.Error().Err(err).Msg("failed to save stores") + } + }) +} + +func load() error { + storesMu.Lock() + defer storesMu.Unlock() + if err := utils.LoadJSONIfExist(storesPath, &stores); err != nil { + return err + } + return nil +} + +func save() error { + storesMu.Lock() + defer storesMu.Unlock() + return utils.SaveJSON(storesPath, &stores, 0o644) +} + +func (s jsonStoreInternal) MarshalJSON() ([]byte, error) { + return json.Marshal(xsync.ToPlainMapOf(s.MapOf)) +} + +func (s jsonStoreInternal) UnmarshalJSON(data []byte) error { + var tmp map[string]any + if err := json.Unmarshal(data, &tmp); err != nil { + return err + } + s.MapOf = xsync.NewMapOf[string, any](xsync.WithPresize(len(tmp))) + for k, v := range tmp { + s.MapOf.Store(k, v) + } + return nil +} diff --git a/internal/jsonstore/jsonstore.go b/internal/jsonstore/jsonstore.go new file mode 100644 index 0000000..944a10b --- /dev/null +++ b/internal/jsonstore/jsonstore.go @@ -0,0 +1,47 @@ +package jsonstore + +import ( + "github.com/puzpuzpuz/xsync/v3" +) + +type JSONStore[VT any] struct{ m jsonStoreInternal } + +func NewStore[VT any](namespace namespace) JSONStore[VT] { + storesMu.Lock() + defer storesMu.Unlock() + if s, ok := stores[namespace]; ok { + return JSONStore[VT]{s} + } + m := jsonStoreInternal{xsync.NewMapOf[string, any]()} + stores[namespace] = m + return JSONStore[VT]{m} +} + +func (s JSONStore[VT]) Load(key string) (_ VT, _ bool) { + value, ok := s.m.Load(key) + if !ok { + return + } + return value.(VT), true +} + +func (s JSONStore[VT]) Has(key string) bool { + _, ok := s.m.Load(key) + return ok +} + +func (s JSONStore[VT]) Store(key string, value VT) { + s.m.Store(key, value) +} + +func (s JSONStore[VT]) Delete(key string) { + s.m.Delete(key) +} + +func (s JSONStore[VT]) Iter(yield func(key string, value VT) bool) { + for k, v := range s.m.Range { + if !yield(k, v.(VT)) { + return + } + } +} diff --git a/internal/jsonstore/jsonstore_test.go b/internal/jsonstore/jsonstore_test.go new file mode 100644 index 0000000..7fa7e04 --- /dev/null +++ b/internal/jsonstore/jsonstore_test.go @@ -0,0 +1,32 @@ +package jsonstore + +import ( + "path/filepath" + "testing" +) + +func TestNewJSON(t *testing.T) { + store := NewStore[string]("test") + store.Store("a", "1") + if v, _ := store.Load("a"); v != "1" { + t.Fatal("expected 1, got", v) + } +} + +func TestSaveLoad(t *testing.T) { + tmpDir := t.TempDir() + storesPath = filepath.Join(tmpDir, "data.json") + store := NewStore[string]("test") + store.Store("a", "1") + if err := save(); err != nil { + t.Fatal(err) + } + stores = nil + if err := load(); err != nil { + t.Fatal(err) + } + store = NewStore[string]("test") + if v, _ := store.Load("a"); v != "1" { + t.Fatal("expected 1, got", v) + } +} diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 73b9c25..3fc62b8 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -6,7 +6,7 @@ import ( "sync" "sync/atomic" - "github.com/yusing/go-proxy/internal/api/v1/auth" + "github.com/yusing/go-proxy/internal/auth" "github.com/yusing/go-proxy/internal/gperr" ) @@ -76,13 +76,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce amw.auth.LogoutHandler(w, r) return false } - if err := amw.auth.CheckToken(r); err != nil { - if errors.Is(err, auth.ErrMissingToken) { - amw.auth.HandleAuth(w, r) - } else { - auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) - } - return false + + err := amw.auth.CheckToken(r) + if err == nil { + return true } - return true + + switch { + case errors.Is(err, auth.ErrMissingToken): + amw.auth.HandleAuth(w, r) + default: + auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath) + } + return false }