diff --git a/cmd/main.go b/cmd/main.go index 7768711..9306529 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -14,7 +14,6 @@ import ( "github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/gperr" "github.com/yusing/go-proxy/internal/homepage" - "github.com/yusing/go-proxy/internal/jsonstore" "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging/memlogger" "github.com/yusing/go-proxy/internal/metrics/systeminfo" @@ -80,7 +79,6 @@ func main() { logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion()) logging.Trace().Msg("trace enabled") parallel( - jsonstore.Initialize, internal.InitIconListCache, homepage.InitOverridesConfig, favicon.InitIconCache, diff --git a/internal/api/handler.go b/internal/api/handler.go index 70f12ea..756a320 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -98,21 +98,18 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { logging.Info().Msg("prometheus metrics enabled") } - // defaultAuth := auth.GetDefaultAuth() - // if defaultAuth != nil { - // mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage) - // mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) { - // if err := defaultAuth.CheckToken(r); err != nil { - // http.Error(w, err.Error(), http.StatusUnauthorized) - // return - // } - // }) - // mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler) - // mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutCallbackHandler) - // } else { - // mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) { - // w.WriteHeader(http.StatusOK) - // }) - // } + defaultAuth := auth.GetDefaultAuth() + if defaultAuth == nil { + return mux + } + + mux.HandleFunc("GET", "/v1/auth/check", auth.AuthCheckHandler) + mux.HandleFunc("GET", "/v1/auth/login", defaultAuth.LoginHandler) + mux.HandleFunc("GET", "/v1/auth/callback", defaultAuth.LoginHandler) + mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutHandler) + switch authProvider := defaultAuth.(type) { + case *auth.OIDCProvider: + mux.HandleFunc("GET", "/v1/auth/postauth", authProvider.PostAuthCallbackHandler) + } return mux } diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 20afcbb..ba0ed00 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -50,3 +50,11 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc { } return next } + +func AuthCheckHandler(w http.ResponseWriter, r *http.Request) { + if err := defaultAuth.CheckToken(r); err != nil { + http.Redirect(w, r, "/v1/auth/login", http.StatusFound) + } else { + w.WriteHeader(http.StatusOK) + } +} diff --git a/internal/auth/oauth_refresh.go b/internal/auth/oauth_refresh.go index 6fce981..df46f8e 100644 --- a/internal/auth/oauth_refresh.go +++ b/internal/auth/oauth_refresh.go @@ -3,6 +3,7 @@ package auth import ( "crypto/rand" "encoding/base64" + "errors" "fmt" "net/http" "time" @@ -33,18 +34,23 @@ type sessionClaims struct { type sessionID string -var oauthRefreshTokens jsonstore.JSONStore[oauthRefreshToken] +var oauthRefreshTokens jsonstore.Typed[oauthRefreshToken] var ( defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month refreshBefore = 30 * time.Second ) +var ( + errNoRefreshToken = errors.New("no refresh token") + ErrRefreshTokenFailure = errors.New("failed to refresh token") +) + const sessionTokenIssuer = "GoDoxy" func init() { if IsOIDCEnabled() { - oauthRefreshTokens = jsonstore.NewStore[oauthRefreshToken]("oauth_refresh_tokens") + oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens") } } @@ -66,6 +72,9 @@ func newSession(username string, groups []string) Session { } } +// getOnceOAuthRefreshToken returns the refresh token for the given session. +// +// The token is removed from the store after retrieval. func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { token, ok := oauthRefreshTokens.Load(string(claims.SessionID)) if !ok { @@ -82,15 +91,16 @@ func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) { } func storeOAuthRefreshToken(sessionID sessionID, username, token string) { - logging.Debug().Str("username", username).Msg("setting oauth refresh token") oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{ Username: username, RefreshToken: token, Expiry: time.Now().Add(defaultRefreshTokenExpiry), }) + logging.Debug().Str("username", username).Msg("stored oauth refresh token") } func invalidateOAuthRefreshToken(sessionID sessionID) { + logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token") oauthRefreshTokens.Delete(string(sessionID)) } @@ -125,26 +135,20 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil } -func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request) error { - // check for session token - sessionCookie, err := r.Cookie(CookieOauthSessionToken) - if err != nil { - return ErrMissingToken - } - +func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error { // verify the session cookie - claims, valid, err := auth.parseSessionJWT(sessionCookie.Value) + claims, valid, err := auth.parseSessionJWT(sessionJWT) if err != nil { - return fmt.Errorf("%w: %w", ErrInvalidToken, err) + return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err) } if !valid { - return ErrInvalidToken + return ErrInvalidSessionToken } // check if refresh is possible refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session) if !ok { - return ErrMissingToken + return errNoRefreshToken } if !auth.checkAllowed(claims.Username, claims.Groups) { diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 0233c73..f2b92d7 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -39,20 +39,17 @@ type ( const ( CookieOauthState = "godoxy_oidc_state" CookieOauthSessionID = "godoxy_session_id" - CookieOauthToken = "godoxy_token" + CookieOauthToken = "godoxy_oauth_token" CookieOauthSessionToken = "godoxy_session_token" ) const ( - OIDCAuthCallbackPath = "/auth/callback" - OIDCPostAuthPath = "/auth/postauth" - OIDCLogoutPath = "/auth/logout" + OIDCAuthInitPath = "/auth/init" + OIDCPostAuthPath = "/auth/postauth" + OIDCLogoutPath = "/auth/logout" ) -var ( - ErrMissingIDToken = errors.New("missing id_token") - ErrRefreshTokenFailure = errors.New("failed to refresh token") -) +var errMissingIDToken = errors.New("missing id_token field from oauth token") // generateState generates a random string for OIDC state. const oidcStateLength = 32 @@ -118,14 +115,17 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) { auth.allowedGroups = groups } +// optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri" +// parameter of the authorization URL to the post auth path of the current +// request host. func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption { - return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath) + return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath) } func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) { idTokenJWT, ok := oauthToken.Extra("id_token").(string) if !ok { - return "", nil, ErrMissingIDToken + return "", nil, errMissingIDToken } idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT) if err != nil { @@ -135,34 +135,37 @@ func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Tok } func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) { - // check for session token - _, err := r.Cookie(CookieOauthSessionToken) - if err == nil { - err := auth.TryRefreshToken(w, r) - if err != nil { - logging.Debug().Err(err).Msg("failed to refresh token") - auth.LogoutHandler(w, r) - } else { - http.Redirect(w, r, "/", http.StatusFound) - } - return - } - switch r.URL.Path { - case OIDCAuthCallbackPath: - state := generateState() - setTokenCookie(w, r, CookieOauthState, state, 300*time.Second) - // redirect user to Idp - http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound) + case OIDCAuthInitPath: + auth.LoginHandler(w, r) case OIDCPostAuthPath: auth.PostAuthCallbackHandler(w, r) case OIDCLogoutPath: auth.LogoutHandler(w, r) default: - http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusFound) + http.Redirect(w, r, OIDCAuthInitPath, http.StatusFound) } } +func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) { + // check for session token + sessionToken, err := r.Cookie(CookieOauthSessionToken) + if err == nil { + err = auth.TryRefreshToken(w, r, sessionToken.Value) + if err != nil { + logging.Debug().Err(err).Msg("failed to refresh token") + auth.clearCookie(w, r) + } + http.Redirect(w, r, "/", http.StatusFound) + return + } + + state := generateState() + setTokenCookie(w, r, CookieOauthState, state, 300*time.Second) + // redirect user to Idp + http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound) +} + func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) { var claim IDTokenClaims if err := idToken.Claims(&claim); err != nil { @@ -188,17 +191,17 @@ func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool { func (auth *OIDCProvider) CheckToken(r *http.Request) error { tokenCookie, err := r.Cookie(CookieOauthToken) if err != nil { - return ErrMissingToken + return ErrMissingOAuthToken } idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value) if err != nil { - return fmt.Errorf("%w: %w", ErrInvalidToken, err) + return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err) } claims, err := parseClaims(idToken) if err != nil { - return fmt.Errorf("%w: %w", ErrInvalidToken, err) + return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err) } if !auth.checkAllowed(claims.Username, claims.Groups) { @@ -270,6 +273,7 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request) if auth.endSessionURL != nil && oauthToken != nil { query := auth.endSessionURL.Query() query.Set("id_token_hint", oauthToken.Value) + query.Set("post_logout_redirect_uri", "https://"+requestHost(r)) clone := *auth.endSessionURL clone.RawQuery = query.Encode() diff --git a/internal/auth/oidc_test.go b/internal/auth/oidc_test.go index 27fe17a..c0791f3 100644 --- a/internal/auth/oidc_test.go +++ b/internal/auth/oidc_test.go @@ -8,6 +8,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "net/url" "testing" "time" @@ -35,7 +36,8 @@ func setupMockOIDC(t *testing.T) { }, Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, }, - oidcProvider: provider, + endSessionURL: Must(url.Parse("http://mock-provider/logout")), + oidcProvider: provider, oidcVerifier: provider.Verifier(&oidc.Config{ ClientID: "test-client", }), @@ -148,14 +150,14 @@ func TestOIDCLoginHandler(t *testing.T) { }{ { name: "Success - Redirects to provider", - wantStatus: http.StatusTemporaryRedirect, + wantStatus: http.StatusFound, wantRedirect: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - req := httptest.NewRequest(http.MethodGet, "/", nil) + req := httptest.NewRequest(http.MethodGet, OIDCAuthInitPath, nil) w := httptest.NewRecorder() defaultAuth.(*OIDCProvider).HandleAuth(w, req) @@ -194,7 +196,7 @@ func TestOIDCCallbackHandler(t *testing.T) { state: "valid-state", code: "valid-code", setupMocks: true, - wantStatus: http.StatusTemporaryRedirect, + wantStatus: http.StatusFound, }, { name: "Failure - Missing state", @@ -396,7 +398,7 @@ func TestCheckToken(t *testing.T) { "preferred_username": "user1", "groups": []string{"group1"}, }, - wantErr: ErrInvalidToken, + wantErr: ErrInvalidOAuthToken, }, { name: "Error - Server returns incorrect audience", @@ -407,7 +409,7 @@ func TestCheckToken(t *testing.T) { "preferred_username": "user1", "groups": []string{"group1"}, }, - wantErr: ErrInvalidToken, + wantErr: ErrInvalidOAuthToken, }, { name: "Error - Server returns expired token", @@ -418,7 +420,7 @@ func TestCheckToken(t *testing.T) { "preferred_username": "user1", "groups": []string{"group1"}, }, - wantErr: ErrInvalidToken, + wantErr: ErrInvalidOAuthToken, }, } for _, tc := range tests { @@ -448,3 +450,35 @@ func TestCheckToken(t *testing.T) { }) } } + +func TestLogoutHandler(t *testing.T) { + t.Helper() + + setupMockOIDC(t) + + req := httptest.NewRequest(http.MethodGet, OIDCLogoutPath, nil) + w := httptest.NewRecorder() + + req.AddCookie(&http.Cookie{ + Name: CookieOauthToken, + Value: "test-token", + }) + req.AddCookie(&http.Cookie{ + Name: CookieOauthSessionToken, + Value: "test-session-token", + }) + + defaultAuth.(*OIDCProvider).LogoutHandler(w, req) + + if got := w.Code; got != http.StatusFound { + t.Errorf("LogoutHandler() status = %v, want %v", got, http.StatusFound) + } + + if got := w.Header().Get("Location"); got == "" { + t.Error("LogoutHandler() missing redirect location") + } + + if len(w.Header().Values("Set-Cookie")) != 2 { + t.Error("LogoutHandler() did not clear all cookies") + } +} diff --git a/internal/auth/provider.go b/internal/auth/provider.go index f56b927..7bd9aa3 100644 --- a/internal/auth/provider.go +++ b/internal/auth/provider.go @@ -4,4 +4,6 @@ import "net/http" type Provider interface { CheckToken(r *http.Request) error + LoginHandler(w http.ResponseWriter, r *http.Request) + LogoutHandler(w http.ResponseWriter, r *http.Request) } diff --git a/internal/auth/userpass.go b/internal/auth/userpass.go index 34e1e81..80e86d2 100644 --- a/internal/auth/userpass.go +++ b/internal/auth/userpass.go @@ -76,7 +76,7 @@ func (auth *UserPassAuth) NewToken() (token string, err error) { func (auth *UserPassAuth) CheckToken(r *http.Request) error { jwtCookie, err := r.Cookie(auth.TokenCookieName()) if err != nil { - return ErrMissingToken + return ErrMissingSessionToken } var claims UserPassClaims token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { @@ -90,7 +90,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error { } switch { case !token.Valid: - return ErrInvalidToken + return ErrInvalidSessionToken case claims.Username != auth.username: return ErrUserNotAllowed.Subject(claims.Username) case claims.ExpiresAt.Before(time.Now()): @@ -100,11 +100,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error { return nil } -func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { - http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) -} - -func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) { +func (auth *UserPassAuth) LoginHandler(w http.ResponseWriter, r *http.Request) { var creds struct { User string `json:"username"` Pass string `json:"password"` @@ -127,9 +123,9 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re w.WriteHeader(http.StatusOK) } -func (auth *UserPassAuth) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) { +func (auth *UserPassAuth) LogoutHandler(w http.ResponseWriter, r *http.Request) { clearTokenCookie(w, r, auth.TokenCookieName()) - auth.RedirectLoginPage(w, r) + http.Redirect(w, r, "/", http.StatusFound) } func (auth *UserPassAuth) validatePassword(user, pass string) error { diff --git a/internal/auth/userpass_test.go b/internal/auth/userpass_test.go index 9a9fbc4..da3c841 100644 --- a/internal/auth/userpass_test.go +++ b/internal/auth/userpass_test.go @@ -98,7 +98,7 @@ func TestUserPassLoginCallbackHandler(t *testing.T) { Host: "app.example.com", Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))), } - auth.LoginCallbackHandler(w, req) + auth.LoginHandler(w, req) if tt.wantErr { ExpectEqual(t, w.Code, http.StatusUnauthorized) } else { diff --git a/internal/auth/utils.go b/internal/auth/utils.go index edefe33..ba0297e 100644 --- a/internal/auth/utils.go +++ b/internal/auth/utils.go @@ -1,7 +1,6 @@ package auth import ( - "net" "net/http" "time" @@ -11,35 +10,34 @@ import ( ) var ( - ErrMissingToken = gperr.New("missing token") - ErrInvalidToken = gperr.New("invalid token") - ErrUserNotAllowed = gperr.New("user not allowed") + ErrMissingOAuthToken = gperr.New("missing oauth token") + ErrMissingSessionToken = gperr.New("missing session token") + ErrInvalidOAuthToken = gperr.New("invalid oauth token") + ErrInvalidSessionToken = gperr.New("invalid session token") + ErrUserNotAllowed = gperr.New("user not allowed") ) -// cookieFQDN returns the fully qualified domain name of the request host +func requestHost(r *http.Request) string { + // check if it's from backend + switch r.Host { + case common.APIHTTPAddr: + // use XFH + return r.Header.Get("X-Forwarded-Host") + default: + return r.Host + } +} + +// cookieDomain returns the fully qualified domain name of the request host // with subdomain stripped. // // If the request host does not have a subdomain, // an empty string is returned // -// "abc.example.com" -> "example.com" -// "example.com" -> "" -func cookieFQDN(r *http.Request) string { - var host string - // check if it's from backend - switch r.Host { - case common.APIHTTPAddr: - // use XFH - host = r.Header.Get("X-Forwarded-Host") - default: - var err error - host, _, err = net.SplitHostPort(r.Host) - if err != nil { - host = r.Host - } - } - - parts := strutils.SplitRune(host, '.') +// "abc.example.com" -> ".example.com" (cross subdomain) +// "example.com" -> "" (same domain only) +func cookieDomain(r *http.Request) string { + parts := strutils.SplitRune(requestHost(r), '.') if len(parts) < 2 { return "" } @@ -52,7 +50,7 @@ func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, Name: name, Value: value, MaxAge: int(ttl.Seconds()), - Domain: cookieFQDN(r), + Domain: cookieDomain(r), HttpOnly: true, Secure: common.APIJWTSecure, SameSite: http.SameSiteLaxMode, @@ -65,7 +63,7 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) { Name: name, Value: "", MaxAge: -1, - Domain: cookieFQDN(r), + Domain: cookieDomain(r), HttpOnly: true, Secure: common.APIJWTSecure, SameSite: http.SameSiteLaxMode, diff --git a/internal/jsonstore/internal.go b/internal/jsonstore/internal.go deleted file mode 100644 index e2ba875..0000000 --- a/internal/jsonstore/internal.go +++ /dev/null @@ -1,63 +0,0 @@ -package jsonstore - -import ( - "encoding/json" - "path/filepath" - "sync" - - "github.com/puzpuzpuz/xsync/v3" - "github.com/yusing/go-proxy/internal/common" - "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/task" - "github.com/yusing/go-proxy/internal/utils" -) - -type jsonStoreInternal struct{ *xsync.MapOf[string, any] } -type namespace string - -var stores = make(map[namespace]jsonStoreInternal) -var storesMu sync.Mutex -var storesPath = filepath.Join(common.DataDir, "data.json") - -func Initialize() { - if err := load(); err != nil { - logging.Error().Err(err).Msg("failed to load stores") - } - - task.OnProgramExit("save_stores", func() { - if err := save(); err != nil { - logging.Error().Err(err).Msg("failed to save stores") - } - }) -} - -func load() error { - storesMu.Lock() - defer storesMu.Unlock() - if err := utils.LoadJSONIfExist(storesPath, &stores); err != nil { - return err - } - return nil -} - -func save() error { - storesMu.Lock() - defer storesMu.Unlock() - return utils.SaveJSON(storesPath, &stores, 0o644) -} - -func (s jsonStoreInternal) MarshalJSON() ([]byte, error) { - return json.Marshal(xsync.ToPlainMapOf(s.MapOf)) -} - -func (s jsonStoreInternal) UnmarshalJSON(data []byte) error { - var tmp map[string]any - if err := json.Unmarshal(data, &tmp); err != nil { - return err - } - s.MapOf = xsync.NewMapOf[string, any](xsync.WithPresize(len(tmp))) - for k, v := range tmp { - s.MapOf.Store(k, v) - } - return nil -} diff --git a/internal/jsonstore/jsonstore.go b/internal/jsonstore/jsonstore.go index 944a10b..d5823e0 100644 --- a/internal/jsonstore/jsonstore.go +++ b/internal/jsonstore/jsonstore.go @@ -1,47 +1,95 @@ package jsonstore import ( + "encoding/json" + "path/filepath" + "sync" + "github.com/puzpuzpuz/xsync/v3" + "github.com/yusing/go-proxy/internal/common" + "github.com/yusing/go-proxy/internal/gperr" + "github.com/yusing/go-proxy/internal/logging" + "github.com/yusing/go-proxy/internal/task" + "github.com/yusing/go-proxy/internal/utils" ) -type JSONStore[VT any] struct{ m jsonStoreInternal } +type namespace string -func NewStore[VT any](namespace namespace) JSONStore[VT] { - storesMu.Lock() - defer storesMu.Unlock() - if s, ok := stores[namespace]; ok { - return JSONStore[VT]{s} +type Typed[VT any] struct { + *xsync.MapOf[string, VT] +} + +type storesMap struct { + sync.RWMutex + m map[namespace]any +} + +var stores = storesMap{m: make(map[namespace]any)} +var storesPath = common.DataDir + +func init() { + if err := load(); err != nil { + logging.Error().Err(err).Msg("failed to load stores") } - m := jsonStoreInternal{xsync.NewMapOf[string, any]()} - stores[namespace] = m - return JSONStore[VT]{m} + + task.OnProgramExit("save_stores", func() { + if err := save(); err != nil { + logging.Error().Err(err).Msg("failed to save stores") + } + }) } -func (s JSONStore[VT]) Load(key string) (_ VT, _ bool) { - value, ok := s.m.Load(key) - if !ok { - return - } - return value.(VT), true -} - -func (s JSONStore[VT]) Has(key string) bool { - _, ok := s.m.Load(key) - return ok -} - -func (s JSONStore[VT]) Store(key string, value VT) { - s.m.Store(key, value) -} - -func (s JSONStore[VT]) Delete(key string) { - s.m.Delete(key) -} - -func (s JSONStore[VT]) Iter(yield func(key string, value VT) bool) { - for k, v := range s.m.Range { - if !yield(k, v.(VT)) { - return +func load() error { + stores.Lock() + defer stores.Unlock() + errs := gperr.NewBuilder("failed to load data stores") + for ns, store := range stores.m { + if err := utils.LoadJSONIfExist(filepath.Join(storesPath, string(ns)+".json"), &store); err != nil { + errs.Add(err) } } + return errs.Error() +} + +func save() error { + stores.Lock() + defer stores.Unlock() + errs := gperr.NewBuilder("failed to save data stores") + for ns, store := range stores.m { + if err := utils.SaveJSON(filepath.Join(common.DataDir, string(ns)+".json"), &store, 0o644); err != nil { + errs.Add(err) + } + } + return errs.Error() +} + +func Store[VT any](namespace namespace) Typed[VT] { + stores.Lock() + defer stores.Unlock() + if s, ok := stores.m[namespace]; ok { + return s.(Typed[VT]) + } + m := Typed[VT]{MapOf: xsync.NewMapOf[string, VT]()} + stores.m[namespace] = m + return m +} + +func (s Typed[VT]) MarshalJSON() ([]byte, error) { + tmp := make(map[string]VT, s.Size()) + for k, v := range s.Range { + tmp[k] = v + } + return json.Marshal(tmp) +} + +func (s Typed[VT]) UnmarshalJSON(data []byte) error { + tmp := make(map[string]VT) + if err := json.Unmarshal(data, &tmp); err != nil { + return err + } + s.MapOf = xsync.NewMapOf[string, VT](xsync.WithPresize(len(tmp))) + for k, v := range tmp { + s.MapOf.Store(k, v) + } + return nil } diff --git a/internal/jsonstore/jsonstore_test.go b/internal/jsonstore/jsonstore_test.go index 7fa7e04..5fe28a2 100644 --- a/internal/jsonstore/jsonstore_test.go +++ b/internal/jsonstore/jsonstore_test.go @@ -6,7 +6,7 @@ import ( ) func TestNewJSON(t *testing.T) { - store := NewStore[string]("test") + store := Store[string]("test") store.Store("a", "1") if v, _ := store.Load("a"); v != "1" { t.Fatal("expected 1, got", v) @@ -16,16 +16,16 @@ func TestNewJSON(t *testing.T) { func TestSaveLoad(t *testing.T) { tmpDir := t.TempDir() storesPath = filepath.Join(tmpDir, "data.json") - store := NewStore[string]("test") + store := Store[string]("test") store.Store("a", "1") if err := save(); err != nil { t.Fatal(err) } - stores = nil + stores.m = nil if err := load(); err != nil { t.Fatal(err) } - store = NewStore[string]("test") + store = Store[string]("test") if v, _ := store.Load("a"); v != "1" { t.Fatal("expected 1, got", v) } diff --git a/internal/net/gphttp/middleware/oidc.go b/internal/net/gphttp/middleware/oidc.go index 3fc62b8..ec62fae 100644 --- a/internal/net/gphttp/middleware/oidc.go +++ b/internal/net/gphttp/middleware/oidc.go @@ -72,18 +72,13 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce return false } - if r.URL.Path == auth.OIDCLogoutPath { - amw.auth.LogoutHandler(w, r) - return false - } - err := amw.auth.CheckToken(r) if err == nil { return true } switch { - case errors.Is(err, auth.ErrMissingToken): + case errors.Is(err, auth.ErrMissingOAuthToken): amw.auth.HandleAuth(w, r) default: auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)