mirror of
https://github.com/yusing/godoxy.git
synced 2025-05-19 20:32:35 +02:00
fix: json store marshaling, api handler
- code clean up - uncomment and simplify api auth handler - fix redirect url for frontend - proper redirect
This commit is contained in:
parent
b815c6fd69
commit
7461344004
14 changed files with 234 additions and 213 deletions
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/yusing/go-proxy/internal/config"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/homepage"
|
||||
"github.com/yusing/go-proxy/internal/jsonstore"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/logging/memlogger"
|
||||
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
|
||||
|
@ -80,7 +79,6 @@ func main() {
|
|||
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
|
||||
logging.Trace().Msg("trace enabled")
|
||||
parallel(
|
||||
jsonstore.Initialize,
|
||||
internal.InitIconListCache,
|
||||
homepage.InitOverridesConfig,
|
||||
favicon.InitIconCache,
|
||||
|
|
|
@ -98,21 +98,18 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
|
|||
logging.Info().Msg("prometheus metrics enabled")
|
||||
}
|
||||
|
||||
// defaultAuth := auth.GetDefaultAuth()
|
||||
// if defaultAuth != nil {
|
||||
// mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage)
|
||||
// mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) {
|
||||
// if err := defaultAuth.CheckToken(r); err != nil {
|
||||
// http.Error(w, err.Error(), http.StatusUnauthorized)
|
||||
// return
|
||||
// }
|
||||
// })
|
||||
// mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler)
|
||||
// mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutCallbackHandler)
|
||||
// } else {
|
||||
// mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) {
|
||||
// w.WriteHeader(http.StatusOK)
|
||||
// })
|
||||
// }
|
||||
defaultAuth := auth.GetDefaultAuth()
|
||||
if defaultAuth == nil {
|
||||
return mux
|
||||
}
|
||||
|
||||
mux.HandleFunc("GET", "/v1/auth/check", auth.AuthCheckHandler)
|
||||
mux.HandleFunc("GET", "/v1/auth/login", defaultAuth.LoginHandler)
|
||||
mux.HandleFunc("GET", "/v1/auth/callback", defaultAuth.LoginHandler)
|
||||
mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutHandler)
|
||||
switch authProvider := defaultAuth.(type) {
|
||||
case *auth.OIDCProvider:
|
||||
mux.HandleFunc("GET", "/v1/auth/postauth", authProvider.PostAuthCallbackHandler)
|
||||
}
|
||||
return mux
|
||||
}
|
||||
|
|
|
@ -50,3 +50,11 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if err := defaultAuth.CheckToken(r); err != nil {
|
||||
http.Redirect(w, r, "/v1/auth/login", http.StatusFound)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,7 @@ package auth
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
@ -33,18 +34,23 @@ type sessionClaims struct {
|
|||
|
||||
type sessionID string
|
||||
|
||||
var oauthRefreshTokens jsonstore.JSONStore[oauthRefreshToken]
|
||||
var oauthRefreshTokens jsonstore.Typed[oauthRefreshToken]
|
||||
|
||||
var (
|
||||
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
|
||||
refreshBefore = 30 * time.Second
|
||||
)
|
||||
|
||||
var (
|
||||
errNoRefreshToken = errors.New("no refresh token")
|
||||
ErrRefreshTokenFailure = errors.New("failed to refresh token")
|
||||
)
|
||||
|
||||
const sessionTokenIssuer = "GoDoxy"
|
||||
|
||||
func init() {
|
||||
if IsOIDCEnabled() {
|
||||
oauthRefreshTokens = jsonstore.NewStore[oauthRefreshToken]("oauth_refresh_tokens")
|
||||
oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -66,6 +72,9 @@ func newSession(username string, groups []string) Session {
|
|||
}
|
||||
}
|
||||
|
||||
// getOnceOAuthRefreshToken returns the refresh token for the given session.
|
||||
//
|
||||
// The token is removed from the store after retrieval.
|
||||
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
||||
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
|
||||
if !ok {
|
||||
|
@ -82,15 +91,16 @@ func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
|
|||
}
|
||||
|
||||
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
|
||||
logging.Debug().Str("username", username).Msg("setting oauth refresh token")
|
||||
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{
|
||||
Username: username,
|
||||
RefreshToken: token,
|
||||
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
|
||||
})
|
||||
logging.Debug().Str("username", username).Msg("stored oauth refresh token")
|
||||
}
|
||||
|
||||
func invalidateOAuthRefreshToken(sessionID sessionID) {
|
||||
logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
|
||||
oauthRefreshTokens.Delete(string(sessionID))
|
||||
}
|
||||
|
||||
|
@ -125,26 +135,20 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
|
|||
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request) error {
|
||||
// check for session token
|
||||
sessionCookie, err := r.Cookie(CookieOauthSessionToken)
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error {
|
||||
// verify the session cookie
|
||||
claims, valid, err := auth.parseSessionJWT(sessionCookie.Value)
|
||||
claims, valid, err := auth.parseSessionJWT(sessionJWT)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||
return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err)
|
||||
}
|
||||
if !valid {
|
||||
return ErrInvalidToken
|
||||
return ErrInvalidSessionToken
|
||||
}
|
||||
|
||||
// check if refresh is possible
|
||||
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session)
|
||||
if !ok {
|
||||
return ErrMissingToken
|
||||
return errNoRefreshToken
|
||||
}
|
||||
|
||||
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||
|
|
|
@ -39,20 +39,17 @@ type (
|
|||
const (
|
||||
CookieOauthState = "godoxy_oidc_state"
|
||||
CookieOauthSessionID = "godoxy_session_id"
|
||||
CookieOauthToken = "godoxy_token"
|
||||
CookieOauthToken = "godoxy_oauth_token"
|
||||
CookieOauthSessionToken = "godoxy_session_token"
|
||||
)
|
||||
|
||||
const (
|
||||
OIDCAuthCallbackPath = "/auth/callback"
|
||||
OIDCPostAuthPath = "/auth/postauth"
|
||||
OIDCLogoutPath = "/auth/logout"
|
||||
OIDCAuthInitPath = "/auth/init"
|
||||
OIDCPostAuthPath = "/auth/postauth"
|
||||
OIDCLogoutPath = "/auth/logout"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingIDToken = errors.New("missing id_token")
|
||||
ErrRefreshTokenFailure = errors.New("failed to refresh token")
|
||||
)
|
||||
var errMissingIDToken = errors.New("missing id_token field from oauth token")
|
||||
|
||||
// generateState generates a random string for OIDC state.
|
||||
const oidcStateLength = 32
|
||||
|
@ -118,14 +115,17 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
|||
auth.allowedGroups = groups
|
||||
}
|
||||
|
||||
// optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri"
|
||||
// parameter of the authorization URL to the post auth path of the current
|
||||
// request host.
|
||||
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
|
||||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath)
|
||||
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
|
||||
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return "", nil, ErrMissingIDToken
|
||||
return "", nil, errMissingIDToken
|
||||
}
|
||||
idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT)
|
||||
if err != nil {
|
||||
|
@ -135,34 +135,37 @@ func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Tok
|
|||
}
|
||||
|
||||
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
|
||||
// check for session token
|
||||
_, err := r.Cookie(CookieOauthSessionToken)
|
||||
if err == nil {
|
||||
err := auth.TryRefreshToken(w, r)
|
||||
if err != nil {
|
||||
logging.Debug().Err(err).Msg("failed to refresh token")
|
||||
auth.LogoutHandler(w, r)
|
||||
} else {
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
switch r.URL.Path {
|
||||
case OIDCAuthCallbackPath:
|
||||
state := generateState()
|
||||
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
|
||||
// redirect user to Idp
|
||||
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
|
||||
case OIDCAuthInitPath:
|
||||
auth.LoginHandler(w, r)
|
||||
case OIDCPostAuthPath:
|
||||
auth.PostAuthCallbackHandler(w, r)
|
||||
case OIDCLogoutPath:
|
||||
auth.LogoutHandler(w, r)
|
||||
default:
|
||||
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusFound)
|
||||
http.Redirect(w, r, OIDCAuthInitPath, http.StatusFound)
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// check for session token
|
||||
sessionToken, err := r.Cookie(CookieOauthSessionToken)
|
||||
if err == nil {
|
||||
err = auth.TryRefreshToken(w, r, sessionToken.Value)
|
||||
if err != nil {
|
||||
logging.Debug().Err(err).Msg("failed to refresh token")
|
||||
auth.clearCookie(w, r)
|
||||
}
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
return
|
||||
}
|
||||
|
||||
state := generateState()
|
||||
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
|
||||
// redirect user to Idp
|
||||
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
|
||||
}
|
||||
|
||||
func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
|
||||
var claim IDTokenClaims
|
||||
if err := idToken.Claims(&claim); err != nil {
|
||||
|
@ -188,17 +191,17 @@ func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
|
|||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||
tokenCookie, err := r.Cookie(CookieOauthToken)
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
return ErrMissingOAuthToken
|
||||
}
|
||||
|
||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||
return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err)
|
||||
}
|
||||
|
||||
claims, err := parseClaims(idToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
|
||||
return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err)
|
||||
}
|
||||
|
||||
if !auth.checkAllowed(claims.Username, claims.Groups) {
|
||||
|
@ -270,6 +273,7 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request)
|
|||
if auth.endSessionURL != nil && oauthToken != nil {
|
||||
query := auth.endSessionURL.Query()
|
||||
query.Set("id_token_hint", oauthToken.Value)
|
||||
query.Set("post_logout_redirect_uri", "https://"+requestHost(r))
|
||||
|
||||
clone := *auth.endSessionURL
|
||||
clone.RawQuery = query.Encode()
|
||||
|
|
|
@ -8,6 +8,7 @@ import (
|
|||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -35,7 +36,8 @@ func setupMockOIDC(t *testing.T) {
|
|||
},
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
},
|
||||
oidcProvider: provider,
|
||||
endSessionURL: Must(url.Parse("http://mock-provider/logout")),
|
||||
oidcProvider: provider,
|
||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: "test-client",
|
||||
}),
|
||||
|
@ -148,14 +150,14 @@ func TestOIDCLoginHandler(t *testing.T) {
|
|||
}{
|
||||
{
|
||||
name: "Success - Redirects to provider",
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
wantStatus: http.StatusFound,
|
||||
wantRedirect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req := httptest.NewRequest(http.MethodGet, OIDCAuthInitPath, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
defaultAuth.(*OIDCProvider).HandleAuth(w, req)
|
||||
|
@ -194,7 +196,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
|||
state: "valid-state",
|
||||
code: "valid-code",
|
||||
setupMocks: true,
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
wantStatus: http.StatusFound,
|
||||
},
|
||||
{
|
||||
name: "Failure - Missing state",
|
||||
|
@ -396,7 +398,7 @@ func TestCheckToken(t *testing.T) {
|
|||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
wantErr: ErrInvalidOAuthToken,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns incorrect audience",
|
||||
|
@ -407,7 +409,7 @@ func TestCheckToken(t *testing.T) {
|
|||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
wantErr: ErrInvalidOAuthToken,
|
||||
},
|
||||
{
|
||||
name: "Error - Server returns expired token",
|
||||
|
@ -418,7 +420,7 @@ func TestCheckToken(t *testing.T) {
|
|||
"preferred_username": "user1",
|
||||
"groups": []string{"group1"},
|
||||
},
|
||||
wantErr: ErrInvalidToken,
|
||||
wantErr: ErrInvalidOAuthToken,
|
||||
},
|
||||
}
|
||||
for _, tc := range tests {
|
||||
|
@ -448,3 +450,35 @@ func TestCheckToken(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogoutHandler(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
setupMockOIDC(t)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, OIDCLogoutPath, nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: CookieOauthToken,
|
||||
Value: "test-token",
|
||||
})
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: CookieOauthSessionToken,
|
||||
Value: "test-session-token",
|
||||
})
|
||||
|
||||
defaultAuth.(*OIDCProvider).LogoutHandler(w, req)
|
||||
|
||||
if got := w.Code; got != http.StatusFound {
|
||||
t.Errorf("LogoutHandler() status = %v, want %v", got, http.StatusFound)
|
||||
}
|
||||
|
||||
if got := w.Header().Get("Location"); got == "" {
|
||||
t.Error("LogoutHandler() missing redirect location")
|
||||
}
|
||||
|
||||
if len(w.Header().Values("Set-Cookie")) != 2 {
|
||||
t.Error("LogoutHandler() did not clear all cookies")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,4 +4,6 @@ import "net/http"
|
|||
|
||||
type Provider interface {
|
||||
CheckToken(r *http.Request) error
|
||||
LoginHandler(w http.ResponseWriter, r *http.Request)
|
||||
LogoutHandler(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
|
|
@ -76,7 +76,7 @@ func (auth *UserPassAuth) NewToken() (token string, err error) {
|
|||
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
return ErrMissingSessionToken
|
||||
}
|
||||
var claims UserPassClaims
|
||||
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||
|
@ -90,7 +90,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
|||
}
|
||||
switch {
|
||||
case !token.Valid:
|
||||
return ErrInvalidToken
|
||||
return ErrInvalidSessionToken
|
||||
case claims.Username != auth.username:
|
||||
return ErrUserNotAllowed.Subject(claims.Username)
|
||||
case claims.ExpiresAt.Before(time.Now()):
|
||||
|
@ -100,11 +100,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (auth *UserPassAuth) LoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds struct {
|
||||
User string `json:"username"`
|
||||
Pass string `json:"password"`
|
||||
|
@ -127,9 +123,9 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
|
|||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
func (auth *UserPassAuth) LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||
auth.RedirectLoginPage(w, r)
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
||||
|
|
|
@ -98,7 +98,7 @@ func TestUserPassLoginCallbackHandler(t *testing.T) {
|
|||
Host: "app.example.com",
|
||||
Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))),
|
||||
}
|
||||
auth.LoginCallbackHandler(w, req)
|
||||
auth.LoginHandler(w, req)
|
||||
if tt.wantErr {
|
||||
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
||||
} else {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
|
@ -11,35 +10,34 @@ import (
|
|||
)
|
||||
|
||||
var (
|
||||
ErrMissingToken = gperr.New("missing token")
|
||||
ErrInvalidToken = gperr.New("invalid token")
|
||||
ErrUserNotAllowed = gperr.New("user not allowed")
|
||||
ErrMissingOAuthToken = gperr.New("missing oauth token")
|
||||
ErrMissingSessionToken = gperr.New("missing session token")
|
||||
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
|
||||
ErrInvalidSessionToken = gperr.New("invalid session token")
|
||||
ErrUserNotAllowed = gperr.New("user not allowed")
|
||||
)
|
||||
|
||||
// cookieFQDN returns the fully qualified domain name of the request host
|
||||
func requestHost(r *http.Request) string {
|
||||
// check if it's from backend
|
||||
switch r.Host {
|
||||
case common.APIHTTPAddr:
|
||||
// use XFH
|
||||
return r.Header.Get("X-Forwarded-Host")
|
||||
default:
|
||||
return r.Host
|
||||
}
|
||||
}
|
||||
|
||||
// cookieDomain returns the fully qualified domain name of the request host
|
||||
// with subdomain stripped.
|
||||
//
|
||||
// If the request host does not have a subdomain,
|
||||
// an empty string is returned
|
||||
//
|
||||
// "abc.example.com" -> "example.com"
|
||||
// "example.com" -> ""
|
||||
func cookieFQDN(r *http.Request) string {
|
||||
var host string
|
||||
// check if it's from backend
|
||||
switch r.Host {
|
||||
case common.APIHTTPAddr:
|
||||
// use XFH
|
||||
host = r.Header.Get("X-Forwarded-Host")
|
||||
default:
|
||||
var err error
|
||||
host, _, err = net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
}
|
||||
|
||||
parts := strutils.SplitRune(host, '.')
|
||||
// "abc.example.com" -> ".example.com" (cross subdomain)
|
||||
// "example.com" -> "" (same domain only)
|
||||
func cookieDomain(r *http.Request) string {
|
||||
parts := strutils.SplitRune(requestHost(r), '.')
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
|
@ -52,7 +50,7 @@ func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string,
|
|||
Name: name,
|
||||
Value: value,
|
||||
MaxAge: int(ttl.Seconds()),
|
||||
Domain: cookieFQDN(r),
|
||||
Domain: cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: common.APIJWTSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
@ -65,7 +63,7 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
|
|||
Name: name,
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
Domain: cookieFQDN(r),
|
||||
Domain: cookieDomain(r),
|
||||
HttpOnly: true,
|
||||
Secure: common.APIJWTSecure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
package jsonstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type jsonStoreInternal struct{ *xsync.MapOf[string, any] }
|
||||
type namespace string
|
||||
|
||||
var stores = make(map[namespace]jsonStoreInternal)
|
||||
var storesMu sync.Mutex
|
||||
var storesPath = filepath.Join(common.DataDir, "data.json")
|
||||
|
||||
func Initialize() {
|
||||
if err := load(); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to load stores")
|
||||
}
|
||||
|
||||
task.OnProgramExit("save_stores", func() {
|
||||
if err := save(); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to save stores")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func load() error {
|
||||
storesMu.Lock()
|
||||
defer storesMu.Unlock()
|
||||
if err := utils.LoadJSONIfExist(storesPath, &stores); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func save() error {
|
||||
storesMu.Lock()
|
||||
defer storesMu.Unlock()
|
||||
return utils.SaveJSON(storesPath, &stores, 0o644)
|
||||
}
|
||||
|
||||
func (s jsonStoreInternal) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(xsync.ToPlainMapOf(s.MapOf))
|
||||
}
|
||||
|
||||
func (s jsonStoreInternal) UnmarshalJSON(data []byte) error {
|
||||
var tmp map[string]any
|
||||
if err := json.Unmarshal(data, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.MapOf = xsync.NewMapOf[string, any](xsync.WithPresize(len(tmp)))
|
||||
for k, v := range tmp {
|
||||
s.MapOf.Store(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -1,47 +1,95 @@
|
|||
package jsonstore
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/puzpuzpuz/xsync/v3"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/gperr"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/task"
|
||||
"github.com/yusing/go-proxy/internal/utils"
|
||||
)
|
||||
|
||||
type JSONStore[VT any] struct{ m jsonStoreInternal }
|
||||
type namespace string
|
||||
|
||||
func NewStore[VT any](namespace namespace) JSONStore[VT] {
|
||||
storesMu.Lock()
|
||||
defer storesMu.Unlock()
|
||||
if s, ok := stores[namespace]; ok {
|
||||
return JSONStore[VT]{s}
|
||||
type Typed[VT any] struct {
|
||||
*xsync.MapOf[string, VT]
|
||||
}
|
||||
|
||||
type storesMap struct {
|
||||
sync.RWMutex
|
||||
m map[namespace]any
|
||||
}
|
||||
|
||||
var stores = storesMap{m: make(map[namespace]any)}
|
||||
var storesPath = common.DataDir
|
||||
|
||||
func init() {
|
||||
if err := load(); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to load stores")
|
||||
}
|
||||
m := jsonStoreInternal{xsync.NewMapOf[string, any]()}
|
||||
stores[namespace] = m
|
||||
return JSONStore[VT]{m}
|
||||
|
||||
task.OnProgramExit("save_stores", func() {
|
||||
if err := save(); err != nil {
|
||||
logging.Error().Err(err).Msg("failed to save stores")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (s JSONStore[VT]) Load(key string) (_ VT, _ bool) {
|
||||
value, ok := s.m.Load(key)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
return value.(VT), true
|
||||
}
|
||||
|
||||
func (s JSONStore[VT]) Has(key string) bool {
|
||||
_, ok := s.m.Load(key)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s JSONStore[VT]) Store(key string, value VT) {
|
||||
s.m.Store(key, value)
|
||||
}
|
||||
|
||||
func (s JSONStore[VT]) Delete(key string) {
|
||||
s.m.Delete(key)
|
||||
}
|
||||
|
||||
func (s JSONStore[VT]) Iter(yield func(key string, value VT) bool) {
|
||||
for k, v := range s.m.Range {
|
||||
if !yield(k, v.(VT)) {
|
||||
return
|
||||
func load() error {
|
||||
stores.Lock()
|
||||
defer stores.Unlock()
|
||||
errs := gperr.NewBuilder("failed to load data stores")
|
||||
for ns, store := range stores.m {
|
||||
if err := utils.LoadJSONIfExist(filepath.Join(storesPath, string(ns)+".json"), &store); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func save() error {
|
||||
stores.Lock()
|
||||
defer stores.Unlock()
|
||||
errs := gperr.NewBuilder("failed to save data stores")
|
||||
for ns, store := range stores.m {
|
||||
if err := utils.SaveJSON(filepath.Join(common.DataDir, string(ns)+".json"), &store, 0o644); err != nil {
|
||||
errs.Add(err)
|
||||
}
|
||||
}
|
||||
return errs.Error()
|
||||
}
|
||||
|
||||
func Store[VT any](namespace namespace) Typed[VT] {
|
||||
stores.Lock()
|
||||
defer stores.Unlock()
|
||||
if s, ok := stores.m[namespace]; ok {
|
||||
return s.(Typed[VT])
|
||||
}
|
||||
m := Typed[VT]{MapOf: xsync.NewMapOf[string, VT]()}
|
||||
stores.m[namespace] = m
|
||||
return m
|
||||
}
|
||||
|
||||
func (s Typed[VT]) MarshalJSON() ([]byte, error) {
|
||||
tmp := make(map[string]VT, s.Size())
|
||||
for k, v := range s.Range {
|
||||
tmp[k] = v
|
||||
}
|
||||
return json.Marshal(tmp)
|
||||
}
|
||||
|
||||
func (s Typed[VT]) UnmarshalJSON(data []byte) error {
|
||||
tmp := make(map[string]VT)
|
||||
if err := json.Unmarshal(data, &tmp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.MapOf = xsync.NewMapOf[string, VT](xsync.WithPresize(len(tmp)))
|
||||
for k, v := range tmp {
|
||||
s.MapOf.Store(k, v)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func TestNewJSON(t *testing.T) {
|
||||
store := NewStore[string]("test")
|
||||
store := Store[string]("test")
|
||||
store.Store("a", "1")
|
||||
if v, _ := store.Load("a"); v != "1" {
|
||||
t.Fatal("expected 1, got", v)
|
||||
|
@ -16,16 +16,16 @@ func TestNewJSON(t *testing.T) {
|
|||
func TestSaveLoad(t *testing.T) {
|
||||
tmpDir := t.TempDir()
|
||||
storesPath = filepath.Join(tmpDir, "data.json")
|
||||
store := NewStore[string]("test")
|
||||
store := Store[string]("test")
|
||||
store.Store("a", "1")
|
||||
if err := save(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
stores = nil
|
||||
stores.m = nil
|
||||
if err := load(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
store = NewStore[string]("test")
|
||||
store = Store[string]("test")
|
||||
if v, _ := store.Load("a"); v != "1" {
|
||||
t.Fatal("expected 1, got", v)
|
||||
}
|
||||
|
|
|
@ -72,18 +72,13 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
|
|||
return false
|
||||
}
|
||||
|
||||
if r.URL.Path == auth.OIDCLogoutPath {
|
||||
amw.auth.LogoutHandler(w, r)
|
||||
return false
|
||||
}
|
||||
|
||||
err := amw.auth.CheckToken(r)
|
||||
if err == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
switch {
|
||||
case errors.Is(err, auth.ErrMissingToken):
|
||||
case errors.Is(err, auth.ErrMissingOAuthToken):
|
||||
amw.auth.HandleAuth(w, r)
|
||||
default:
|
||||
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)
|
||||
|
|
Loading…
Add table
Reference in a new issue