mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix: json store marshaling, api handler
- code clean up - uncomment and simplify api auth handler - fix redirect url for frontend - proper redirect
This commit is contained in:
parent
b815c6fd69
commit
7461344004
14 changed files with 234 additions and 213 deletions
|
@ -14,7 +14,6 @@ import (
|
||||||
"github.com/yusing/go-proxy/internal/config"
|
"github.com/yusing/go-proxy/internal/config"
|
||||||
"github.com/yusing/go-proxy/internal/gperr"
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
"github.com/yusing/go-proxy/internal/homepage"
|
"github.com/yusing/go-proxy/internal/homepage"
|
||||||
"github.com/yusing/go-proxy/internal/jsonstore"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||||
|
@ -80,7 +79,6 @@ func main() {
|
||||||
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
||||||
logging.Trace().Msg("trace enabled")
|
logging.Trace().Msg("trace enabled")
|
||||||
parallel(
|
parallel(
|
||||||
jsonstore.Initialize,
|
|
||||||
internal.InitIconListCache,
|
internal.InitIconListCache,
|
||||||
homepage.InitOverridesConfig,
|
homepage.InitOverridesConfig,
|
||||||
favicon.InitIconCache,
|
favicon.InitIconCache,
|
||||||
|
|
|
@ -98,21 +98,18 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
|
||||||
logging.Info().Msg("prometheus metrics enabled")
|
logging.Info().Msg("prometheus metrics enabled")
|
||||||
}
|
}
|
||||||
|
|
||||||
// defaultAuth := auth.GetDefaultAuth()
|
defaultAuth := auth.GetDefaultAuth()
|
||||||
// if defaultAuth != nil {
|
if defaultAuth == nil {
|
||||||
// mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage)
|
return mux
|
||||||
// mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) {
|
}
|
||||||
// if err := defaultAuth.CheckToken(r); err != nil {
|
|
||||||
// http.Error(w, err.Error(), http.StatusUnauthorized)
|
mux.HandleFunc("GET", "/v1/auth/check", auth.AuthCheckHandler)
|
||||||
// return
|
mux.HandleFunc("GET", "/v1/auth/login", defaultAuth.LoginHandler)
|
||||||
// }
|
mux.HandleFunc("GET", "/v1/auth/callback", defaultAuth.LoginHandler)
|
||||||
// })
|
mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutHandler)
|
||||||
// mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler)
|
switch authProvider := defaultAuth.(type) {
|
||||||
// mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutCallbackHandler)
|
case *auth.OIDCProvider:
|
||||||
// } else {
|
mux.HandleFunc("GET", "/v1/auth/postauth", authProvider.PostAuthCallbackHandler)
|
||||||
// mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) {
|
}
|
||||||
// w.WriteHeader(http.StatusOK)
|
|
||||||
// })
|
|
||||||
// }
|
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
|
@ -50,3 +50,11 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
}
|
}
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if err := defaultAuth.CheckToken(r); err != nil {
|
||||||
|
http.Redirect(w, r, "/v1/auth/login", http.StatusFound)
|
||||||
|
} else {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package auth
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
@ -33,18 +34,23 @@ type sessionClaims struct {
|
||||||
|
|
||||||
type sessionID string
|
type sessionID string
|
||||||
|
|
||||||
var oauthRefreshTokens jsonstore.JSONStore[oauthRefreshToken]
|
var oauthRefreshTokens jsonstore.Typed[oauthRefreshToken]
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
||||||
refreshBefore = 30 * time.Second
|
refreshBefore = 30 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errNoRefreshToken = errors.New("no refresh token")
|
||||||
|
ErrRefreshTokenFailure = errors.New("failed to refresh token")
|
||||||
|
)
|
||||||
|
|
||||||
const sessionTokenIssuer = "GoDoxy"
|
const sessionTokenIssuer = "GoDoxy"
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
if IsOIDCEnabled() {
|
if IsOIDCEnabled() {
|
||||||
oauthRefreshTokens = jsonstore.NewStore[oauthRefreshToken]("oauth_refresh_tokens")
|
oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,6 +72,9 @@ func newSession(username string, groups []string) Session {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getOnceOAuthRefreshToken returns the refresh token for the given session.
|
||||||
|
//
|
||||||
|
// The token is removed from the store after retrieval.
|
||||||
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
||||||
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
|
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -82,15 +91,16 @@ func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
||||||
logging.Debug().Str("username", username).Msg("setting oauth refresh token")
|
|
||||||
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{
|
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{
|
||||||
Username: username,
|
Username: username,
|
||||||
RefreshToken: token,
|
RefreshToken: token,
|
||||||
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
||||||
})
|
})
|
||||||
|
logging.Debug().Str("username", username).Msg("stored oauth refresh token")
|
||||||
}
|
}
|
||||||
|
|
||||||
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
||||||
|
logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
|
||||||
oauthRefreshTokens.Delete(string(sessionID))
|
oauthRefreshTokens.Delete(string(sessionID))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -125,26 +135,20 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
|
||||||
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request) error {
|
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error {
|
||||||
// check for session token
|
|
||||||
sessionCookie, err := r.Cookie(CookieOauthSessionToken)
|
|
||||||
if err != nil {
|
|
||||||
return ErrMissingToken
|
|
||||||
}
|
|
||||||
|
|
||||||
// verify the session cookie
|
// verify the session cookie
|
||||||
claims, valid, err := auth.parseSessionJWT(sessionCookie.Value)
|
claims, valid, err := auth.parseSessionJWT(sessionJWT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err)
|
||||||
}
|
}
|
||||||
if !valid {
|
if !valid {
|
||||||
return ErrInvalidToken
|
return ErrInvalidSessionToken
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if refresh is possible
|
// check if refresh is possible
|
||||||
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session)
|
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session)
|
||||||
if !ok {
|
if !ok {
|
||||||
return ErrMissingToken
|
return errNoRefreshToken
|
||||||
}
|
}
|
||||||
|
|
||||||
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||||
|
|
|
@ -39,20 +39,17 @@ type (
|
||||||
const (
|
const (
|
||||||
CookieOauthState = "godoxy_oidc_state"
|
CookieOauthState = "godoxy_oidc_state"
|
||||||
CookieOauthSessionID = "godoxy_session_id"
|
CookieOauthSessionID = "godoxy_session_id"
|
||||||
CookieOauthToken = "godoxy_token"
|
CookieOauthToken = "godoxy_oauth_token"
|
||||||
CookieOauthSessionToken = "godoxy_session_token"
|
CookieOauthSessionToken = "godoxy_session_token"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
OIDCAuthCallbackPath = "/auth/callback"
|
OIDCAuthInitPath = "/auth/init"
|
||||||
OIDCPostAuthPath = "/auth/postauth"
|
OIDCPostAuthPath = "/auth/postauth"
|
||||||
OIDCLogoutPath = "/auth/logout"
|
OIDCLogoutPath = "/auth/logout"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var errMissingIDToken = errors.New("missing id_token field from oauth token")
|
||||||
ErrMissingIDToken = errors.New("missing id_token")
|
|
||||||
ErrRefreshTokenFailure = errors.New("failed to refresh token")
|
|
||||||
)
|
|
||||||
|
|
||||||
// generateState generates a random string for OIDC state.
|
// generateState generates a random string for OIDC state.
|
||||||
const oidcStateLength = 32
|
const oidcStateLength = 32
|
||||||
|
@ -118,14 +115,17 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
||||||
auth.allowedGroups = groups
|
auth.allowedGroups = groups
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri"
|
||||||
|
// parameter of the authorization URL to the post auth path of the current
|
||||||
|
// request host.
|
||||||
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
||||||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath)
|
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
||||||
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
return "", nil, ErrMissingIDToken
|
return "", nil, errMissingIDToken
|
||||||
}
|
}
|
||||||
idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT)
|
idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -135,34 +135,37 @@ func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Tok
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||||
// check for session token
|
|
||||||
_, err := r.Cookie(CookieOauthSessionToken)
|
|
||||||
if err == nil {
|
|
||||||
err := auth.TryRefreshToken(w, r)
|
|
||||||
if err != nil {
|
|
||||||
logging.Debug().Err(err).Msg("failed to refresh token")
|
|
||||||
auth.LogoutHandler(w, r)
|
|
||||||
} else {
|
|
||||||
http.Redirect(w, r, "/", http.StatusFound)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch r.URL.Path {
|
switch r.URL.Path {
|
||||||
case OIDCAuthCallbackPath:
|
case OIDCAuthInitPath:
|
||||||
state := generateState()
|
auth.LoginHandler(w, r)
|
||||||
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
|
|
||||||
// redirect user to Idp
|
|
||||||
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
|
|
||||||
case OIDCPostAuthPath:
|
case OIDCPostAuthPath:
|
||||||
auth.PostAuthCallbackHandler(w, r)
|
auth.PostAuthCallbackHandler(w, r)
|
||||||
case OIDCLogoutPath:
|
case OIDCLogoutPath:
|
||||||
auth.LogoutHandler(w, r)
|
auth.LogoutHandler(w, r)
|
||||||
default:
|
default:
|
||||||
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusFound)
|
http.Redirect(w, r, OIDCAuthInitPath, http.StatusFound)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// check for session token
|
||||||
|
sessionToken, err := r.Cookie(CookieOauthSessionToken)
|
||||||
|
if err == nil {
|
||||||
|
err = auth.TryRefreshToken(w, r, sessionToken.Value)
|
||||||
|
if err != nil {
|
||||||
|
logging.Debug().Err(err).Msg("failed to refresh token")
|
||||||
|
auth.clearCookie(w, r)
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
state := generateState()
|
||||||
|
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
|
||||||
|
// redirect user to Idp
|
||||||
|
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
||||||
var claim IDTokenClaims
|
var claim IDTokenClaims
|
||||||
if err := idToken.Claims(&claim); err != nil {
|
if err := idToken.Claims(&claim); err != nil {
|
||||||
|
@ -188,17 +191,17 @@ func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
||||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
tokenCookie, err := r.Cookie(CookieOauthToken)
|
tokenCookie, err := r.Cookie(CookieOauthToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrMissingToken
|
return ErrMissingOAuthToken
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value)
|
idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
claims, err := parseClaims(idToken)
|
claims, err := parseClaims(idToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||||
|
@ -270,6 +273,7 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request)
|
||||||
if auth.endSessionURL != nil && oauthToken != nil {
|
if auth.endSessionURL != nil && oauthToken != nil {
|
||||||
query := auth.endSessionURL.Query()
|
query := auth.endSessionURL.Query()
|
||||||
query.Set("id_token_hint", oauthToken.Value)
|
query.Set("id_token_hint", oauthToken.Value)
|
||||||
|
query.Set("post_logout_redirect_uri", "https://"+requestHost(r))
|
||||||
|
|
||||||
clone := *auth.endSessionURL
|
clone := *auth.endSessionURL
|
||||||
clone.RawQuery = query.Encode()
|
clone.RawQuery = query.Encode()
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -35,7 +36,8 @@ func setupMockOIDC(t *testing.T) {
|
||||||
},
|
},
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
},
|
},
|
||||||
oidcProvider: provider,
|
endSessionURL: Must(url.Parse("http://mock-provider/logout")),
|
||||||
|
oidcProvider: provider,
|
||||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
ClientID: "test-client",
|
ClientID: "test-client",
|
||||||
}),
|
}),
|
||||||
|
@ -148,14 +150,14 @@ func TestOIDCLoginHandler(t *testing.T) {
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Success - Redirects to provider",
|
name: "Success - Redirects to provider",
|
||||||
wantStatus: http.StatusTemporaryRedirect,
|
wantStatus: http.StatusFound,
|
||||||
wantRedirect: true,
|
wantRedirect: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, OIDCAuthInitPath, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
defaultAuth.(*OIDCProvider).HandleAuth(w, req)
|
defaultAuth.(*OIDCProvider).HandleAuth(w, req)
|
||||||
|
@ -194,7 +196,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
state: "valid-state",
|
state: "valid-state",
|
||||||
code: "valid-code",
|
code: "valid-code",
|
||||||
setupMocks: true,
|
setupMocks: true,
|
||||||
wantStatus: http.StatusTemporaryRedirect,
|
wantStatus: http.StatusFound,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Failure - Missing state",
|
name: "Failure - Missing state",
|
||||||
|
@ -396,7 +398,7 @@ func TestCheckToken(t *testing.T) {
|
||||||
"preferred_username": "user1",
|
"preferred_username": "user1",
|
||||||
"groups": []string{"group1"},
|
"groups": []string{"group1"},
|
||||||
},
|
},
|
||||||
wantErr: ErrInvalidToken,
|
wantErr: ErrInvalidOAuthToken,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Error - Server returns incorrect audience",
|
name: "Error - Server returns incorrect audience",
|
||||||
|
@ -407,7 +409,7 @@ func TestCheckToken(t *testing.T) {
|
||||||
"preferred_username": "user1",
|
"preferred_username": "user1",
|
||||||
"groups": []string{"group1"},
|
"groups": []string{"group1"},
|
||||||
},
|
},
|
||||||
wantErr: ErrInvalidToken,
|
wantErr: ErrInvalidOAuthToken,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Error - Server returns expired token",
|
name: "Error - Server returns expired token",
|
||||||
|
@ -418,7 +420,7 @@ func TestCheckToken(t *testing.T) {
|
||||||
"preferred_username": "user1",
|
"preferred_username": "user1",
|
||||||
"groups": []string{"group1"},
|
"groups": []string{"group1"},
|
||||||
},
|
},
|
||||||
wantErr: ErrInvalidToken,
|
wantErr: ErrInvalidOAuthToken,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
@ -448,3 +450,35 @@ func TestCheckToken(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestLogoutHandler(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
setupMockOIDC(t)
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodGet, OIDCLogoutPath, nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: CookieOauthToken,
|
||||||
|
Value: "test-token",
|
||||||
|
})
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: CookieOauthSessionToken,
|
||||||
|
Value: "test-session-token",
|
||||||
|
})
|
||||||
|
|
||||||
|
defaultAuth.(*OIDCProvider).LogoutHandler(w, req)
|
||||||
|
|
||||||
|
if got := w.Code; got != http.StatusFound {
|
||||||
|
t.Errorf("LogoutHandler() status = %v, want %v", got, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
if got := w.Header().Get("Location"); got == "" {
|
||||||
|
t.Error("LogoutHandler() missing redirect location")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(w.Header().Values("Set-Cookie")) != 2 {
|
||||||
|
t.Error("LogoutHandler() did not clear all cookies")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,4 +4,6 @@ import "net/http"
|
||||||
|
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
CheckToken(r *http.Request) error
|
CheckToken(r *http.Request) error
|
||||||
|
LoginHandler(w http.ResponseWriter, r *http.Request)
|
||||||
|
LogoutHandler(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
|
@ -76,7 +76,7 @@ func (auth *UserPassAuth) NewToken() (token string, err error) {
|
||||||
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||||
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrMissingToken
|
return ErrMissingSessionToken
|
||||||
}
|
}
|
||||||
var claims UserPassClaims
|
var claims UserPassClaims
|
||||||
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
@ -90,7 +90,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||||
}
|
}
|
||||||
switch {
|
switch {
|
||||||
case !token.Valid:
|
case !token.Valid:
|
||||||
return ErrInvalidToken
|
return ErrInvalidSessionToken
|
||||||
case claims.Username != auth.username:
|
case claims.Username != auth.username:
|
||||||
return ErrUserNotAllowed.Subject(claims.Username)
|
return ErrUserNotAllowed.Subject(claims.Username)
|
||||||
case claims.ExpiresAt.Before(time.Now()):
|
case claims.ExpiresAt.Before(time.Now()):
|
||||||
|
@ -100,11 +100,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
func (auth *UserPassAuth) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
var creds struct {
|
var creds struct {
|
||||||
User string `json:"username"`
|
User string `json:"username"`
|
||||||
Pass string `json:"password"`
|
Pass string `json:"password"`
|
||||||
|
@ -127,9 +123,9 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *UserPassAuth) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
func (auth *UserPassAuth) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||||
auth.RedirectLoginPage(w, r)
|
http.Redirect(w, r, "/", http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
||||||
|
|
|
@ -98,7 +98,7 @@ func TestUserPassLoginCallbackHandler(t *testing.T) {
|
||||||
Host: "app.example.com",
|
Host: "app.example.com",
|
||||||
Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))),
|
Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))),
|
||||||
}
|
}
|
||||||
auth.LoginCallbackHandler(w, req)
|
auth.LoginHandler(w, req)
|
||||||
if tt.wantErr {
|
if tt.wantErr {
|
||||||
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -11,35 +10,34 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrMissingToken = gperr.New("missing token")
|
ErrMissingOAuthToken = gperr.New("missing oauth token")
|
||||||
ErrInvalidToken = gperr.New("invalid token")
|
ErrMissingSessionToken = gperr.New("missing session token")
|
||||||
ErrUserNotAllowed = gperr.New("user not allowed")
|
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
|
||||||
|
ErrInvalidSessionToken = gperr.New("invalid session token")
|
||||||
|
ErrUserNotAllowed = gperr.New("user not allowed")
|
||||||
)
|
)
|
||||||
|
|
||||||
// cookieFQDN returns the fully qualified domain name of the request host
|
func requestHost(r *http.Request) string {
|
||||||
|
// check if it's from backend
|
||||||
|
switch r.Host {
|
||||||
|
case common.APIHTTPAddr:
|
||||||
|
// use XFH
|
||||||
|
return r.Header.Get("X-Forwarded-Host")
|
||||||
|
default:
|
||||||
|
return r.Host
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cookieDomain returns the fully qualified domain name of the request host
|
||||||
// with subdomain stripped.
|
// with subdomain stripped.
|
||||||
//
|
//
|
||||||
// If the request host does not have a subdomain,
|
// If the request host does not have a subdomain,
|
||||||
// an empty string is returned
|
// an empty string is returned
|
||||||
//
|
//
|
||||||
// "abc.example.com" -> "example.com"
|
// "abc.example.com" -> ".example.com" (cross subdomain)
|
||||||
// "example.com" -> ""
|
// "example.com" -> "" (same domain only)
|
||||||
func cookieFQDN(r *http.Request) string {
|
func cookieDomain(r *http.Request) string {
|
||||||
var host string
|
parts := strutils.SplitRune(requestHost(r), '.')
|
||||||
// check if it's from backend
|
|
||||||
switch r.Host {
|
|
||||||
case common.APIHTTPAddr:
|
|
||||||
// use XFH
|
|
||||||
host = r.Header.Get("X-Forwarded-Host")
|
|
||||||
default:
|
|
||||||
var err error
|
|
||||||
host, _, err = net.SplitHostPort(r.Host)
|
|
||||||
if err != nil {
|
|
||||||
host = r.Host
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
parts := strutils.SplitRune(host, '.')
|
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -52,7 +50,7 @@ func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string,
|
||||||
Name: name,
|
Name: name,
|
||||||
Value: value,
|
Value: value,
|
||||||
MaxAge: int(ttl.Seconds()),
|
MaxAge: int(ttl.Seconds()),
|
||||||
Domain: cookieFQDN(r),
|
Domain: cookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: common.APIJWTSecure,
|
Secure: common.APIJWTSecure,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
@ -65,7 +63,7 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
Name: name,
|
Name: name,
|
||||||
Value: "",
|
Value: "",
|
||||||
MaxAge: -1,
|
MaxAge: -1,
|
||||||
Domain: cookieFQDN(r),
|
Domain: cookieDomain(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: common.APIJWTSecure,
|
Secure: common.APIJWTSecure,
|
||||||
SameSite: http.SameSiteLaxMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
|
|
@ -1,63 +0,0 @@
|
||||||
package jsonstore
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
"github.com/yusing/go-proxy/internal/logging"
|
|
||||||
"github.com/yusing/go-proxy/internal/task"
|
|
||||||
"github.com/yusing/go-proxy/internal/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
type jsonStoreInternal struct{ *xsync.MapOf[string, any] }
|
|
||||||
type namespace string
|
|
||||||
|
|
||||||
var stores = make(map[namespace]jsonStoreInternal)
|
|
||||||
var storesMu sync.Mutex
|
|
||||||
var storesPath = filepath.Join(common.DataDir, "data.json")
|
|
||||||
|
|
||||||
func Initialize() {
|
|
||||||
if err := load(); err != nil {
|
|
||||||
logging.Error().Err(err).Msg("failed to load stores")
|
|
||||||
}
|
|
||||||
|
|
||||||
task.OnProgramExit("save_stores", func() {
|
|
||||||
if err := save(); err != nil {
|
|
||||||
logging.Error().Err(err).Msg("failed to save stores")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func load() error {
|
|
||||||
storesMu.Lock()
|
|
||||||
defer storesMu.Unlock()
|
|
||||||
if err := utils.LoadJSONIfExist(storesPath, &stores); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func save() error {
|
|
||||||
storesMu.Lock()
|
|
||||||
defer storesMu.Unlock()
|
|
||||||
return utils.SaveJSON(storesPath, &stores, 0o644)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s jsonStoreInternal) MarshalJSON() ([]byte, error) {
|
|
||||||
return json.Marshal(xsync.ToPlainMapOf(s.MapOf))
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s jsonStoreInternal) UnmarshalJSON(data []byte) error {
|
|
||||||
var tmp map[string]any
|
|
||||||
if err := json.Unmarshal(data, &tmp); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.MapOf = xsync.NewMapOf[string, any](xsync.WithPresize(len(tmp)))
|
|
||||||
for k, v := range tmp {
|
|
||||||
s.MapOf.Store(k, v)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
|
@ -1,47 +1,95 @@
|
||||||
package jsonstore
|
package jsonstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/puzpuzpuz/xsync/v3"
|
"github.com/puzpuzpuz/xsync/v3"
|
||||||
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
"github.com/yusing/go-proxy/internal/gperr"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
|
"github.com/yusing/go-proxy/internal/task"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type JSONStore[VT any] struct{ m jsonStoreInternal }
|
type namespace string
|
||||||
|
|
||||||
func NewStore[VT any](namespace namespace) JSONStore[VT] {
|
type Typed[VT any] struct {
|
||||||
storesMu.Lock()
|
*xsync.MapOf[string, VT]
|
||||||
defer storesMu.Unlock()
|
}
|
||||||
if s, ok := stores[namespace]; ok {
|
|
||||||
return JSONStore[VT]{s}
|
type storesMap struct {
|
||||||
|
sync.RWMutex
|
||||||
|
m map[namespace]any
|
||||||
|
}
|
||||||
|
|
||||||
|
var stores = storesMap{m: make(map[namespace]any)}
|
||||||
|
var storesPath = common.DataDir
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
if err := load(); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("failed to load stores")
|
||||||
}
|
}
|
||||||
m := jsonStoreInternal{xsync.NewMapOf[string, any]()}
|
|
||||||
stores[namespace] = m
|
task.OnProgramExit("save_stores", func() {
|
||||||
return JSONStore[VT]{m}
|
if err := save(); err != nil {
|
||||||
|
logging.Error().Err(err).Msg("failed to save stores")
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s JSONStore[VT]) Load(key string) (_ VT, _ bool) {
|
func load() error {
|
||||||
value, ok := s.m.Load(key)
|
stores.Lock()
|
||||||
if !ok {
|
defer stores.Unlock()
|
||||||
return
|
errs := gperr.NewBuilder("failed to load data stores")
|
||||||
}
|
for ns, store := range stores.m {
|
||||||
return value.(VT), true
|
if err := utils.LoadJSONIfExist(filepath.Join(storesPath, string(ns)+".json"), &store); err != nil {
|
||||||
}
|
errs.Add(err)
|
||||||
|
|
||||||
func (s JSONStore[VT]) Has(key string) bool {
|
|
||||||
_, ok := s.m.Load(key)
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s JSONStore[VT]) Store(key string, value VT) {
|
|
||||||
s.m.Store(key, value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s JSONStore[VT]) Delete(key string) {
|
|
||||||
s.m.Delete(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s JSONStore[VT]) Iter(yield func(key string, value VT) bool) {
|
|
||||||
for k, v := range s.m.Range {
|
|
||||||
if !yield(k, v.(VT)) {
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func save() error {
|
||||||
|
stores.Lock()
|
||||||
|
defer stores.Unlock()
|
||||||
|
errs := gperr.NewBuilder("failed to save data stores")
|
||||||
|
for ns, store := range stores.m {
|
||||||
|
if err := utils.SaveJSON(filepath.Join(common.DataDir, string(ns)+".json"), &store, 0o644); err != nil {
|
||||||
|
errs.Add(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errs.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func Store[VT any](namespace namespace) Typed[VT] {
|
||||||
|
stores.Lock()
|
||||||
|
defer stores.Unlock()
|
||||||
|
if s, ok := stores.m[namespace]; ok {
|
||||||
|
return s.(Typed[VT])
|
||||||
|
}
|
||||||
|
m := Typed[VT]{MapOf: xsync.NewMapOf[string, VT]()}
|
||||||
|
stores.m[namespace] = m
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Typed[VT]) MarshalJSON() ([]byte, error) {
|
||||||
|
tmp := make(map[string]VT, s.Size())
|
||||||
|
for k, v := range s.Range {
|
||||||
|
tmp[k] = v
|
||||||
|
}
|
||||||
|
return json.Marshal(tmp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s Typed[VT]) UnmarshalJSON(data []byte) error {
|
||||||
|
tmp := make(map[string]VT)
|
||||||
|
if err := json.Unmarshal(data, &tmp); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
s.MapOf = xsync.NewMapOf[string, VT](xsync.WithPresize(len(tmp)))
|
||||||
|
for k, v := range tmp {
|
||||||
|
s.MapOf.Store(k, v)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestNewJSON(t *testing.T) {
|
func TestNewJSON(t *testing.T) {
|
||||||
store := NewStore[string]("test")
|
store := Store[string]("test")
|
||||||
store.Store("a", "1")
|
store.Store("a", "1")
|
||||||
if v, _ := store.Load("a"); v != "1" {
|
if v, _ := store.Load("a"); v != "1" {
|
||||||
t.Fatal("expected 1, got", v)
|
t.Fatal("expected 1, got", v)
|
||||||
|
@ -16,16 +16,16 @@ func TestNewJSON(t *testing.T) {
|
||||||
func TestSaveLoad(t *testing.T) {
|
func TestSaveLoad(t *testing.T) {
|
||||||
tmpDir := t.TempDir()
|
tmpDir := t.TempDir()
|
||||||
storesPath = filepath.Join(tmpDir, "data.json")
|
storesPath = filepath.Join(tmpDir, "data.json")
|
||||||
store := NewStore[string]("test")
|
store := Store[string]("test")
|
||||||
store.Store("a", "1")
|
store.Store("a", "1")
|
||||||
if err := save(); err != nil {
|
if err := save(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
stores = nil
|
stores.m = nil
|
||||||
if err := load(); err != nil {
|
if err := load(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
store = NewStore[string]("test")
|
store = Store[string]("test")
|
||||||
if v, _ := store.Load("a"); v != "1" {
|
if v, _ := store.Load("a"); v != "1" {
|
||||||
t.Fatal("expected 1, got", v)
|
t.Fatal("expected 1, got", v)
|
||||||
}
|
}
|
||||||
|
|
|
@ -72,18 +72,13 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Path == auth.OIDCLogoutPath {
|
|
||||||
amw.auth.LogoutHandler(w, r)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
err := amw.auth.CheckToken(r)
|
err := amw.auth.CheckToken(r)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
case errors.Is(err, auth.ErrMissingToken):
|
case errors.Is(err, auth.ErrMissingOAuthToken):
|
||||||
amw.auth.HandleAuth(w, r)
|
amw.auth.HandleAuth(w, r)
|
||||||
default:
|
default:
|
||||||
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
||||||
|
|
Loading…
Add table
Reference in a new issue