mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-20 04:42:33 +02:00
feat(oidc): support token refreshing via offline_access scope
- refactored code - moved api/v1/auth to auth/ - security enhancement - env example update - default jwt ttl changed to 24 hours
This commit is contained in:
parent
28c9a2e9d0
commit
b815c6fd69
21 changed files with 668 additions and 310 deletions
17
.env.example
17
.env.example
|
@ -1,23 +1,26 @@
|
||||||
# set timezone to get correct log timestamp
|
# set timezone to get correct log timestamp
|
||||||
TZ=ETC/UTC
|
TZ=ETC/UTC
|
||||||
|
|
||||||
|
# API JWT Configuration (common)
|
||||||
|
# generate secret with `openssl rand -base64 32`
|
||||||
|
GODOXY_API_JWT_SECRET=
|
||||||
|
# the JWT token time-to-live
|
||||||
|
# leave empty to use default (24 hours)
|
||||||
|
# format: https://pkg.go.dev/time#Duration
|
||||||
|
GODOXY_API_JWT_TOKEN_TTL=
|
||||||
|
|
||||||
# API/WebUI user password login credentials (optional)
|
# API/WebUI user password login credentials (optional)
|
||||||
# These fields are not required for OIDC authentication
|
# These fields are not required for OIDC authentication
|
||||||
GODOXY_API_USER=admin
|
GODOXY_API_USER=admin
|
||||||
GODOXY_API_PASSWORD=password
|
GODOXY_API_PASSWORD=password
|
||||||
# generate secret with `openssl rand -base64 32`
|
|
||||||
GODOXY_API_JWT_SECRET=
|
|
||||||
# the JWT token time-to-live
|
|
||||||
GODOXY_API_JWT_TOKEN_TTL=1h
|
|
||||||
|
|
||||||
# OIDC Configuration (optional)
|
# OIDC Configuration (optional)
|
||||||
# Uncomment and configure these values to enable OIDC authentication.
|
# Uncomment and configure these values to enable OIDC authentication.
|
||||||
|
# For `GODOXY_OIDC_SCOPES` you may also include `offline_access` if your Idp supports it (e.g. Authentik)
|
||||||
|
#
|
||||||
# GODOXY_OIDC_ISSUER_URL=https://accounts.google.com
|
# GODOXY_OIDC_ISSUER_URL=https://accounts.google.com
|
||||||
# GODOXY_OIDC_CLIENT_ID=your-client-id
|
# GODOXY_OIDC_CLIENT_ID=your-client-id
|
||||||
# GODOXY_OIDC_CLIENT_SECRET=your-client-secret
|
# GODOXY_OIDC_CLIENT_SECRET=your-client-secret
|
||||||
# Keep /api/auth/callback as the redirect URL, change the domain to match your setup.
|
|
||||||
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
|
||||||
# Comma-separated list of scopes
|
|
||||||
# GODOXY_OIDC_SCOPES=openid, profile, email
|
# GODOXY_OIDC_SCOPES=openid, profile, email
|
||||||
#
|
#
|
||||||
# User definitions: Uncomment and configure these values to restrict access to specific users or groups.
|
# User definitions: Uncomment and configure these values to restrict access to specific users or groups.
|
||||||
|
|
|
@ -7,13 +7,14 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal"
|
"github.com/yusing/go-proxy/internal"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/query"
|
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||||
|
"github.com/yusing/go-proxy/internal/auth"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/config"
|
"github.com/yusing/go-proxy/internal/config"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/homepage"
|
"github.com/yusing/go-proxy/internal/homepage"
|
||||||
|
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||||
|
@ -79,6 +80,7 @@ func main() {
|
||||||
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
||||||
logging.Trace().Msg("trace enabled")
|
logging.Trace().Msg("trace enabled")
|
||||||
parallel(
|
parallel(
|
||||||
|
jsonstore.Initialize,
|
||||||
internal.InitIconListCache,
|
internal.InitIconListCache,
|
||||||
homepage.InitOverridesConfig,
|
homepage.InitOverridesConfig,
|
||||||
favicon.InitIconCache,
|
favicon.InitIconCache,
|
||||||
|
|
|
@ -6,10 +6,10 @@ import (
|
||||||
|
|
||||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/certapi"
|
"github.com/yusing/go-proxy/internal/api/v1/certapi"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/dockerapi"
|
"github.com/yusing/go-proxy/internal/api/v1/dockerapi"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
||||||
|
"github.com/yusing/go-proxy/internal/auth"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
|
|
@ -1,272 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/url"
|
|
||||||
"slices"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/net/gphttp"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
)
|
|
||||||
|
|
||||||
type (
|
|
||||||
OIDCProvider struct {
|
|
||||||
oauthConfig *oauth2.Config
|
|
||||||
oidcProvider *oidc.Provider
|
|
||||||
oidcVerifier *oidc.IDTokenVerifier
|
|
||||||
oidcEndSessionURL *url.URL
|
|
||||||
allowedUsers []string
|
|
||||||
allowedGroups []string
|
|
||||||
}
|
|
||||||
|
|
||||||
providerJSON struct {
|
|
||||||
oidc.ProviderConfig
|
|
||||||
EndSessionURL string `json:"end_session_endpoint"`
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
const CookieOauthState = "godoxy_oidc_state"
|
|
||||||
|
|
||||||
const (
|
|
||||||
OIDCAuthCallbackPath = "/auth/callback"
|
|
||||||
OIDCPostAuthPath = "/auth/postauth"
|
|
||||||
OIDCLogoutPath = "/auth/logout"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrMissingState = errors.New("missing state cookie")
|
|
||||||
ErrInvalidState = errors.New("invalid oauth state")
|
|
||||||
)
|
|
||||||
|
|
||||||
func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
|
|
||||||
if len(allowedUsers)+len(allowedGroups) == 0 {
|
|
||||||
return nil, errors.New("OIDC users, groups, or both must not be empty")
|
|
||||||
}
|
|
||||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
|
|
||||||
if err != nil && provider.EndSessionEndpoint() != "" {
|
|
||||||
// non critical, just warn
|
|
||||||
logging.Warn().
|
|
||||||
Str("issuer", issuerURL).
|
|
||||||
Err(err).
|
|
||||||
Msg("failed to parse end session URL")
|
|
||||||
}
|
|
||||||
|
|
||||||
return &OIDCProvider{
|
|
||||||
oauthConfig: &oauth2.Config{
|
|
||||||
ClientID: clientID,
|
|
||||||
ClientSecret: clientSecret,
|
|
||||||
RedirectURL: "",
|
|
||||||
Endpoint: provider.Endpoint(),
|
|
||||||
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
|
||||||
},
|
|
||||||
oidcProvider: provider,
|
|
||||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
|
||||||
ClientID: clientID,
|
|
||||||
}),
|
|
||||||
oidcEndSessionURL: endSessionURL,
|
|
||||||
allowedUsers: allowedUsers,
|
|
||||||
allowedGroups: allowedGroups,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables.
|
|
||||||
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
|
||||||
return NewOIDCProvider(
|
|
||||||
common.OIDCIssuerURL,
|
|
||||||
common.OIDCClientID,
|
|
||||||
common.OIDCClientSecret,
|
|
||||||
common.OIDCAllowedUsers,
|
|
||||||
common.OIDCAllowedGroups,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) TokenCookieName() string {
|
|
||||||
return "godoxy_oidc_token"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
|
||||||
auth.allowedUsers = users
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
|
||||||
auth.allowedGroups = groups
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) getVerifyStateCookie(r *http.Request) (string, error) {
|
|
||||||
state, err := r.Cookie(CookieOauthState)
|
|
||||||
if err != nil {
|
|
||||||
return "", ErrMissingState
|
|
||||||
}
|
|
||||||
if r.URL.Query().Get("state") != state.Value {
|
|
||||||
return "", ErrInvalidState
|
|
||||||
}
|
|
||||||
return state.Value, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
|
||||||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.Method {
|
|
||||||
case http.MethodHead:
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
return
|
|
||||||
case http.MethodGet:
|
|
||||||
break
|
|
||||||
default:
|
|
||||||
gphttp.Forbidden(w, "method not allowed")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch r.URL.Path {
|
|
||||||
case OIDCAuthCallbackPath:
|
|
||||||
state := generateState()
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: CookieOauthState,
|
|
||||||
Value: state,
|
|
||||||
MaxAge: 300,
|
|
||||||
HttpOnly: true,
|
|
||||||
SameSite: http.SameSiteLaxMode,
|
|
||||||
Secure: common.APIJWTSecure,
|
|
||||||
Path: "/",
|
|
||||||
})
|
|
||||||
// redirect user to Idp
|
|
||||||
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusTemporaryRedirect)
|
|
||||||
case OIDCPostAuthPath:
|
|
||||||
auth.PostAuthCallbackHandler(w, r)
|
|
||||||
default:
|
|
||||||
auth.LogoutHandler(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
|
||||||
token, err := r.Cookie(auth.TokenCookieName())
|
|
||||||
if err != nil {
|
|
||||||
return ErrMissingToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
|
||||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(idToken.Audience) == 0 {
|
|
||||||
return ErrInvalidToken
|
|
||||||
}
|
|
||||||
|
|
||||||
var claims struct {
|
|
||||||
Email string `json:"email"`
|
|
||||||
Username string `json:"preferred_username"`
|
|
||||||
Groups []string `json:"groups"`
|
|
||||||
}
|
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
|
||||||
return fmt.Errorf("failed to parse claims: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Logical AND between allowed users and groups.
|
|
||||||
allowedUser := slices.Contains(auth.allowedUsers, claims.Username)
|
|
||||||
allowedGroup := len(utils.Intersect(claims.Groups, auth.allowedGroups)) > 0
|
|
||||||
if !allowedUser && !allowedGroup {
|
|
||||||
return ErrUserNotAllowed
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
// For testing purposes, skip provider verification
|
|
||||||
if common.IsTest {
|
|
||||||
auth.handleTestCallback(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := auth.getVerifyStateCookie(r)
|
|
||||||
if err != nil {
|
|
||||||
gphttp.BadRequest(w, err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
|
||||||
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r))
|
|
||||||
if err != nil {
|
|
||||||
gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
|
||||||
if !ok {
|
|
||||||
gphttp.BadRequest(w, "missing id_token")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken)
|
|
||||||
if err != nil {
|
|
||||||
gphttp.ServerError(w, r, fmt.Errorf("failed to verify ID token: %w", err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry))
|
|
||||||
|
|
||||||
// Redirect to home page
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if auth.oidcEndSessionURL == nil {
|
|
||||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
|
||||||
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := r.Cookie(auth.TokenCookieName())
|
|
||||||
if err == nil {
|
|
||||||
query := auth.oidcEndSessionURL.Query()
|
|
||||||
query.Add("id_token_hint", token.Value)
|
|
||||||
|
|
||||||
logoutURL := *auth.oidcEndSessionURL
|
|
||||||
logoutURL.RawQuery = query.Encode()
|
|
||||||
|
|
||||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
|
||||||
http.Redirect(w, r, logoutURL.String(), http.StatusFound)
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleTestCallback handles OIDC callback in test environment.
|
|
||||||
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
|
||||||
state, err := r.Cookie(CookieOauthState)
|
|
||||||
if err != nil {
|
|
||||||
gphttp.BadRequest(w, "missing state cookie")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.URL.Query().Get("state") != state.Value {
|
|
||||||
gphttp.BadRequest(w, "invalid oauth state")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create test JWT token
|
|
||||||
setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour)
|
|
||||||
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
181
internal/auth/oauth_refresh.go
Normal file
181
internal/auth/oauth_refresh.go
Normal file
|
@ -0,0 +1,181 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type oauthRefreshToken struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
Expiry time.Time `json:"expiry"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
SessionID sessionID `json:"session_id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionClaims struct {
|
||||||
|
Session
|
||||||
|
jwt.RegisteredClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
type sessionID string
|
||||||
|
|
||||||
|
var oauthRefreshTokens jsonstore.JSONStore[oauthRefreshToken]
|
||||||
|
|
||||||
|
var (
|
||||||
|
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
||||||
|
refreshBefore = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
const sessionTokenIssuer = "GoDoxy"
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if IsOIDCEnabled() {
|
||||||
|
oauthRefreshTokens = jsonstore.NewStore[oauthRefreshToken]("oauth_refresh_tokens")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (token *oauthRefreshToken) expired() bool {
|
||||||
|
return time.Now().After(token.Expiry)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSessionID() sessionID {
|
||||||
|
b := make([]byte, 32)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return sessionID(base64.StdEncoding.EncodeToString(b))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSession(username string, groups []string) Session {
|
||||||
|
return Session{
|
||||||
|
SessionID: newSessionID(),
|
||||||
|
Username: username,
|
||||||
|
Groups: groups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
||||||
|
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
invalidateOAuthRefreshToken(claims.SessionID)
|
||||||
|
if token.expired() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
if claims.Username != token.Username {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return &token, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
||||||
|
logging.Debug().Str("username", username).Msg("setting oauth refresh token")
|
||||||
|
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{
|
||||||
|
Username: username,
|
||||||
|
RefreshToken: token,
|
||||||
|
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
||||||
|
oauthRefreshTokens.Delete(string(sessionID))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) setSessionTokenCookie(w http.ResponseWriter, r *http.Request, session Session) {
|
||||||
|
claims := &sessionClaims{
|
||||||
|
Session: session,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
Issuer: sessionTokenIssuer,
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(common.APIJWTTokenTTL)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS512, claims)
|
||||||
|
signed, err := jwtToken.SignedString(common.APIJWTSecret)
|
||||||
|
if err != nil {
|
||||||
|
logging.Err(err).Msg("failed to sign session token")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
setTokenCookie(w, r, CookieOauthSessionToken, signed, common.APIJWTTokenTTL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionClaims, valid bool, err error) {
|
||||||
|
claims = &sessionClaims{}
|
||||||
|
sessionToken, err := jwt.ParseWithClaims(sessionJWT, claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||||
|
}
|
||||||
|
return common.APIJWTSecret, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
// check for session token
|
||||||
|
sessionCookie, err := r.Cookie(CookieOauthSessionToken)
|
||||||
|
if err != nil {
|
||||||
|
return ErrMissingToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify the session cookie
|
||||||
|
claims, valid, err := auth.parseSessionJWT(sessionCookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||||
|
}
|
||||||
|
if !valid {
|
||||||
|
return ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
// check if refresh is possible
|
||||||
|
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session)
|
||||||
|
if !ok {
|
||||||
|
return ErrMissingToken
|
||||||
|
}
|
||||||
|
|
||||||
|
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||||
|
return ErrUserNotAllowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// this step refreshes the token
|
||||||
|
// see https://cs.opensource.google/go/x/oauth2/+/refs/tags/v0.29.0:oauth2.go;l=313
|
||||||
|
newToken, err := auth.oauthConfig.TokenSource(r.Context(), &oauth2.Token{
|
||||||
|
RefreshToken: refreshToken.RefreshToken,
|
||||||
|
}).Token()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %w", ErrRefreshTokenFailure, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), newToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := newSessionID()
|
||||||
|
|
||||||
|
logging.Debug().Str("username", claims.Username).Time("expiry", newToken.Expiry).Msg("refreshed token")
|
||||||
|
storeOAuthRefreshToken(sessionID, claims.Username, newToken.RefreshToken)
|
||||||
|
|
||||||
|
// set new idToken and new sessionToken
|
||||||
|
auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry))
|
||||||
|
auth.setSessionTokenCookie(w, r, Session{
|
||||||
|
SessionID: sessionID,
|
||||||
|
Username: claims.Username,
|
||||||
|
Groups: claims.Groups,
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
310
internal/auth/oidc.go
Normal file
310
internal/auth/oidc.go
Normal file
|
@ -0,0 +1,310 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
"github.com/yusing/go-proxy/internal/net/gphttp"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
type (
|
||||||
|
OIDCProvider struct {
|
||||||
|
oauthConfig *oauth2.Config
|
||||||
|
oidcProvider *oidc.Provider
|
||||||
|
oidcVerifier *oidc.IDTokenVerifier
|
||||||
|
endSessionURL *url.URL
|
||||||
|
allowedUsers []string
|
||||||
|
allowedGroups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
IDTokenClaims struct {
|
||||||
|
Username string `json:"preferred_username"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CookieOauthState = "godoxy_oidc_state"
|
||||||
|
CookieOauthSessionID = "godoxy_session_id"
|
||||||
|
CookieOauthToken = "godoxy_token"
|
||||||
|
CookieOauthSessionToken = "godoxy_session_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCAuthCallbackPath = "/auth/callback"
|
||||||
|
OIDCPostAuthPath = "/auth/postauth"
|
||||||
|
OIDCLogoutPath = "/auth/logout"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMissingIDToken = errors.New("missing id_token")
|
||||||
|
ErrRefreshTokenFailure = errors.New("failed to refresh token")
|
||||||
|
)
|
||||||
|
|
||||||
|
// generateState generates a random string for OIDC state.
|
||||||
|
const oidcStateLength = 32
|
||||||
|
|
||||||
|
func generateState() string {
|
||||||
|
b := make([]byte, oidcStateLength)
|
||||||
|
_, _ = rand.Read(b)
|
||||||
|
return base64.URLEncoding.EncodeToString(b)[:oidcStateLength]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCProvider(issuerURL, clientID, clientSecret string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
|
||||||
|
if len(allowedUsers)+len(allowedGroups) == 0 {
|
||||||
|
return nil, errors.New("OIDC users, groups, or both must not be empty")
|
||||||
|
}
|
||||||
|
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
endSessionURL, err := url.Parse(provider.EndSessionEndpoint())
|
||||||
|
if err != nil && provider.EndSessionEndpoint() != "" {
|
||||||
|
// non critical, just warn
|
||||||
|
logging.Warn().
|
||||||
|
Str("issuer", issuerURL).
|
||||||
|
Err(err).
|
||||||
|
Msg("failed to parse end session URL")
|
||||||
|
}
|
||||||
|
|
||||||
|
return &OIDCProvider{
|
||||||
|
oauthConfig: &oauth2.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: "",
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
||||||
|
},
|
||||||
|
oidcProvider: provider,
|
||||||
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
}),
|
||||||
|
endSessionURL: endSessionURL,
|
||||||
|
allowedUsers: allowedUsers,
|
||||||
|
allowedGroups: allowedGroups,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables.
|
||||||
|
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
||||||
|
return NewOIDCProvider(
|
||||||
|
common.OIDCIssuerURL,
|
||||||
|
common.OIDCClientID,
|
||||||
|
common.OIDCClientSecret,
|
||||||
|
common.OIDCAllowedUsers,
|
||||||
|
common.OIDCAllowedGroups,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||||
|
auth.allowedUsers = users
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
||||||
|
auth.allowedGroups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
||||||
|
return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
||||||
|
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
||||||
|
if !ok {
|
||||||
|
return "", nil, ErrMissingIDToken
|
||||||
|
}
|
||||||
|
idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT)
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, fmt.Errorf("failed to verify ID token: %w", err)
|
||||||
|
}
|
||||||
|
return idTokenJWT, idToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// check for session token
|
||||||
|
_, err := r.Cookie(CookieOauthSessionToken)
|
||||||
|
if err == nil {
|
||||||
|
err := auth.TryRefreshToken(w, r)
|
||||||
|
if err != nil {
|
||||||
|
logging.Debug().Err(err).Msg("failed to refresh token")
|
||||||
|
auth.LogoutHandler(w, r)
|
||||||
|
} else {
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch r.URL.Path {
|
||||||
|
case OIDCAuthCallbackPath:
|
||||||
|
state := generateState()
|
||||||
|
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
|
||||||
|
// redirect user to Idp
|
||||||
|
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
|
||||||
|
case OIDCPostAuthPath:
|
||||||
|
auth.PostAuthCallbackHandler(w, r)
|
||||||
|
case OIDCLogoutPath:
|
||||||
|
auth.LogoutHandler(w, r)
|
||||||
|
default:
|
||||||
|
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusFound)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
||||||
|
var claim IDTokenClaims
|
||||||
|
if err := idToken.Claims(&claim); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse claims: %w", err)
|
||||||
|
}
|
||||||
|
if claim.Username == "" {
|
||||||
|
return nil, fmt.Errorf("missing username in ID token")
|
||||||
|
}
|
||||||
|
return &claim, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
||||||
|
userAllowed := slices.Contains(auth.allowedUsers, user)
|
||||||
|
if !userAllowed {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(auth.allowedGroups) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return len(utils.Intersect(groups, auth.allowedGroups)) > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
|
tokenCookie, err := r.Cookie(CookieOauthToken)
|
||||||
|
if err != nil {
|
||||||
|
return ErrMissingToken
|
||||||
|
}
|
||||||
|
|
||||||
|
idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, err := parseClaims(idToken)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||||
|
return ErrUserNotAllowed
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) PostAuthCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// For testing purposes, skip provider verification
|
||||||
|
if common.IsTest {
|
||||||
|
auth.handleTestCallback(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// verify state
|
||||||
|
state, err := r.Cookie(CookieOauthState)
|
||||||
|
if err != nil {
|
||||||
|
gphttp.BadRequest(w, "missing state cookie")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if r.URL.Query().Get("state") != state.Value {
|
||||||
|
gphttp.BadRequest(w, "invalid oauth state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
code := r.URL.Query().Get("code")
|
||||||
|
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code, optRedirectPostAuth(r))
|
||||||
|
if err != nil {
|
||||||
|
gphttp.ServerError(w, r, fmt.Errorf("failed to exchange token: %w", err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
idTokenJWT, idToken, err := auth.getIdToken(r.Context(), oauth2Token)
|
||||||
|
if err != nil {
|
||||||
|
gphttp.ServerError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if oauth2Token.RefreshToken != "" {
|
||||||
|
claims, err := parseClaims(idToken)
|
||||||
|
if err != nil {
|
||||||
|
gphttp.ServerError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
session := newSession(claims.Username, claims.Groups)
|
||||||
|
storeOAuthRefreshToken(session.SessionID, claims.Username, oauth2Token.RefreshToken)
|
||||||
|
auth.setSessionTokenCookie(w, r, session)
|
||||||
|
}
|
||||||
|
auth.setIDTokenCookie(w, r, idTokenJWT, time.Until(idToken.Expiry))
|
||||||
|
|
||||||
|
// Redirect to home page
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
oauthToken, _ := r.Cookie(CookieOauthToken)
|
||||||
|
sessionToken, _ := r.Cookie(CookieOauthSessionToken)
|
||||||
|
auth.clearCookie(w, r)
|
||||||
|
|
||||||
|
if sessionToken != nil {
|
||||||
|
claims, _, err := auth.parseSessionJWT(sessionToken.Value)
|
||||||
|
if err == nil {
|
||||||
|
invalidateOAuthRefreshToken(claims.SessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
url := "/"
|
||||||
|
if auth.endSessionURL != nil && oauthToken != nil {
|
||||||
|
query := auth.endSessionURL.Query()
|
||||||
|
query.Set("id_token_hint", oauthToken.Value)
|
||||||
|
|
||||||
|
clone := *auth.endSessionURL
|
||||||
|
clone.RawQuery = query.Encode()
|
||||||
|
url = clone.String()
|
||||||
|
} else if auth.endSessionURL != nil {
|
||||||
|
url = auth.endSessionURL.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
http.Redirect(w, r, url, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) setIDTokenCookie(w http.ResponseWriter, r *http.Request, jwt string, ttl time.Duration) {
|
||||||
|
setTokenCookie(w, r, CookieOauthToken, jwt, ttl)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) clearCookie(w http.ResponseWriter, r *http.Request) {
|
||||||
|
clearTokenCookie(w, r, CookieOauthToken)
|
||||||
|
clearTokenCookie(w, r, CookieOauthSessionToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleTestCallback handles OIDC callback in test environment.
|
||||||
|
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
state, err := r.Cookie(CookieOauthState)
|
||||||
|
if err != nil {
|
||||||
|
gphttp.BadRequest(w, "missing state cookie")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if r.URL.Query().Get("state") != state.Value {
|
||||||
|
gphttp.BadRequest(w, "invalid oauth state")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test JWT token
|
||||||
|
setTokenCookie(w, r, CookieOauthToken, "test", time.Hour)
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
}
|
|
@ -227,7 +227,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
|
|
||||||
if tt.wantStatus == http.StatusTemporaryRedirect {
|
if tt.wantStatus == http.StatusTemporaryRedirect {
|
||||||
setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
setCookie := Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||||
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
|
ExpectEqual(t, setCookie.Name, CookieOauthToken)
|
||||||
ExpectTrue(t, setCookie.Value != "")
|
ExpectTrue(t, setCookie.Value != "")
|
||||||
ExpectEqual(t, setCookie.Path, "/")
|
ExpectEqual(t, setCookie.Path, "/")
|
||||||
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||||
|
@ -434,7 +434,7 @@ func TestCheckToken(t *testing.T) {
|
||||||
// Craft a test HTTP request that includes the token as a cookie.
|
// Craft a test HTTP request that includes the token as a cookie.
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
req.AddCookie(&http.Cookie{
|
req.AddCookie(&http.Cookie{
|
||||||
Name: auth.TokenCookieName(),
|
Name: CookieOauthToken,
|
||||||
Value: signedToken,
|
Value: signedToken,
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,10 +1,7 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import "net/http"
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
TokenCookieName() string
|
|
||||||
CheckToken(r *http.Request) error
|
CheckToken(r *http.Request) error
|
||||||
}
|
}
|
|
@ -1,8 +1,6 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -74,12 +72,3 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
Path: "/",
|
Path: "/",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateState generates a random string for OIDC state.
|
|
||||||
const oidcStateLength = 32
|
|
||||||
|
|
||||||
func generateState() string {
|
|
||||||
b := make([]byte, oidcStateLength)
|
|
||||||
_, _ = rand.Read(b)
|
|
||||||
return base64.URLEncoding.EncodeToString(b)[:oidcStateLength]
|
|
||||||
}
|
|
|
@ -23,6 +23,8 @@ const (
|
||||||
ComposeFileName = "compose.yml"
|
ComposeFileName = "compose.yml"
|
||||||
ComposeExampleFileName = "compose.example.yml"
|
ComposeExampleFileName = "compose.example.yml"
|
||||||
|
|
||||||
|
DataDir = "data"
|
||||||
|
|
||||||
ErrorPagesBasePath = "error_pages"
|
ErrorPagesBasePath = "error_pages"
|
||||||
|
|
||||||
AgentCertsBasePath = "certs"
|
AgentCertsBasePath = "certs"
|
||||||
|
|
|
@ -13,7 +13,7 @@ func decodeJWTKey(key string) []byte {
|
||||||
}
|
}
|
||||||
bytes, err := base64.StdEncoding.DecodeString(key)
|
bytes, err := base64.StdEncoding.DecodeString(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic().Err(err).Msg("failed to decode jwt key")
|
log.Fatal().Str("key", key).Err(err).Msg("failed to decode secret")
|
||||||
}
|
}
|
||||||
return bytes
|
return bytes
|
||||||
}
|
}
|
||||||
|
@ -22,7 +22,7 @@ func RandomJWTKey() []byte {
|
||||||
key := make([]byte, 32)
|
key := make([]byte, 32)
|
||||||
_, err := rand.Read(key)
|
_, err := rand.Read(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Panic().Err(err).Msg("failed to generate random jwt key")
|
log.Fatal().Err(err).Msg("failed to generate random jwt key")
|
||||||
}
|
}
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ var (
|
||||||
|
|
||||||
APIJWTSecure = GetEnvBool("API_JWT_SECURE", true)
|
APIJWTSecure = GetEnvBool("API_JWT_SECURE", true)
|
||||||
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
|
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
|
||||||
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
|
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", 24*time.Hour)
|
||||||
APIUser = GetEnvString("API_USER", "admin")
|
APIUser = GetEnvString("API_USER", "admin")
|
||||||
APIPassword = GetEnvString("API_PASSWORD", "password")
|
APIPassword = GetEnvString("API_PASSWORD", "password")
|
||||||
|
|
||||||
|
|
63
internal/jsonstore/internal.go
Normal file
63
internal/jsonstore/internal.go
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
package jsonstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type jsonStoreInternal struct{ *xsync.MapOf[string, any] }
|
||||||
|
type namespace string
|
||||||
|
|
||||||
|
var stores = make(map[namespace]jsonStoreInternal)
|
||||||
|
var storesMu sync.Mutex
|
||||||
|
var storesPath = filepath.Join(common.DataDir, "data.json")
|
||||||
|
|
||||||
|
func Initialize() {
|
||||||
|
if err := load(); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("failed to load stores")
|
||||||
|
}
|
||||||
|
|
||||||
|
task.OnProgramExit("save_stores", func() {
|
||||||
|
if err := save(); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("failed to save stores")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func load() error {
|
||||||
|
storesMu.Lock()
|
||||||
|
defer storesMu.Unlock()
|
||||||
|
if err := utils.LoadJSONIfExist(storesPath, &stores); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func save() error {
|
||||||
|
storesMu.Lock()
|
||||||
|
defer storesMu.Unlock()
|
||||||
|
return utils.SaveJSON(storesPath, &stores, 0o644)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s jsonStoreInternal) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(xsync.ToPlainMapOf(s.MapOf))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s jsonStoreInternal) UnmarshalJSON(data []byte) error {
|
||||||
|
var tmp map[string]any
|
||||||
|
if err := json.Unmarshal(data, &tmp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.MapOf = xsync.NewMapOf[string, any](xsync.WithPresize(len(tmp)))
|
||||||
|
for k, v := range tmp {
|
||||||
|
s.MapOf.Store(k, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
47
internal/jsonstore/jsonstore.go
Normal file
47
internal/jsonstore/jsonstore.go
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
package jsonstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JSONStore[VT any] struct{ m jsonStoreInternal }
|
||||||
|
|
||||||
|
func NewStore[VT any](namespace namespace) JSONStore[VT] {
|
||||||
|
storesMu.Lock()
|
||||||
|
defer storesMu.Unlock()
|
||||||
|
if s, ok := stores[namespace]; ok {
|
||||||
|
return JSONStore[VT]{s}
|
||||||
|
}
|
||||||
|
m := jsonStoreInternal{xsync.NewMapOf[string, any]()}
|
||||||
|
stores[namespace] = m
|
||||||
|
return JSONStore[VT]{m}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONStore[VT]) Load(key string) (_ VT, _ bool) {
|
||||||
|
value, ok := s.m.Load(key)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return value.(VT), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONStore[VT]) Has(key string) bool {
|
||||||
|
_, ok := s.m.Load(key)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONStore[VT]) Store(key string, value VT) {
|
||||||
|
s.m.Store(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONStore[VT]) Delete(key string) {
|
||||||
|
s.m.Delete(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s JSONStore[VT]) Iter(yield func(key string, value VT) bool) {
|
||||||
|
for k, v := range s.m.Range {
|
||||||
|
if !yield(k, v.(VT)) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
32
internal/jsonstore/jsonstore_test.go
Normal file
32
internal/jsonstore/jsonstore_test.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package jsonstore
|
||||||
|
|
||||||
|
import (
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewJSON(t *testing.T) {
|
||||||
|
store := NewStore[string]("test")
|
||||||
|
store.Store("a", "1")
|
||||||
|
if v, _ := store.Load("a"); v != "1" {
|
||||||
|
t.Fatal("expected 1, got", v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveLoad(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
storesPath = filepath.Join(tmpDir, "data.json")
|
||||||
|
store := NewStore[string]("test")
|
||||||
|
store.Store("a", "1")
|
||||||
|
if err := save(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
stores = nil
|
||||||
|
if err := load(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
store = NewStore[string]("test")
|
||||||
|
if v, _ := store.Load("a"); v != "1" {
|
||||||
|
t.Fatal("expected 1, got", v)
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
"github.com/yusing/go-proxy/internal/auth"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -76,13 +76,17 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
||||||
amw.auth.LogoutHandler(w, r)
|
amw.auth.LogoutHandler(w, r)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if err := amw.auth.CheckToken(r); err != nil {
|
|
||||||
if errors.Is(err, auth.ErrMissingToken) {
|
err := amw.auth.CheckToken(r)
|
||||||
|
if err == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, auth.ErrMissingToken):
|
||||||
amw.auth.HandleAuth(w, r)
|
amw.auth.HandleAuth(w, r)
|
||||||
} else {
|
default:
|
||||||
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue