fix: json store marshaling, api handler

- code clean up
- uncomment and simplify api auth handler
- fix redirect url for frontend
- proper redirect
This commit is contained in:
yusing 2025-04-24 04:47:42 +08:00
parent b815c6fd69
commit 7461344004
14 changed files with 234 additions and 213 deletions

View file

@ -14,7 +14,6 @@ import (
"github.com/yusing/go-proxy/internal/config"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/homepage"
"github.com/yusing/go-proxy/internal/jsonstore"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/logging/memlogger"
"github.com/yusing/go-proxy/internal/metrics/systeminfo"
@ -80,7 +79,6 @@ func main() {
logging.Info().Msgf("GoDoxy version %s", pkg.GetVersion())
logging.Trace().Msg("trace enabled")
parallel(
jsonstore.Initialize,
internal.InitIconListCache,
homepage.InitOverridesConfig,
favicon.InitIconCache,

View file

@ -98,21 +98,18 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
logging.Info().Msg("prometheus metrics enabled")
}
// defaultAuth := auth.GetDefaultAuth()
// if defaultAuth != nil {
// mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage)
// mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) {
// if err := defaultAuth.CheckToken(r); err != nil {
// http.Error(w, err.Error(), http.StatusUnauthorized)
// return
// }
// })
// mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler)
// mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutCallbackHandler)
// } else {
// mux.HandleFunc("GET", "/v1/auth/check", func(w http.ResponseWriter, r *http.Request) {
// w.WriteHeader(http.StatusOK)
// })
// }
defaultAuth := auth.GetDefaultAuth()
if defaultAuth == nil {
return mux
}
mux.HandleFunc("GET", "/v1/auth/check", auth.AuthCheckHandler)
mux.HandleFunc("GET", "/v1/auth/login", defaultAuth.LoginHandler)
mux.HandleFunc("GET", "/v1/auth/callback", defaultAuth.LoginHandler)
mux.HandleFunc("GET,POST", "/v1/auth/logout", defaultAuth.LogoutHandler)
switch authProvider := defaultAuth.(type) {
case *auth.OIDCProvider:
mux.HandleFunc("GET", "/v1/auth/postauth", authProvider.PostAuthCallbackHandler)
}
return mux
}

View file

@ -50,3 +50,11 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
}
return next
}
func AuthCheckHandler(w http.ResponseWriter, r *http.Request) {
if err := defaultAuth.CheckToken(r); err != nil {
http.Redirect(w, r, "/v1/auth/login", http.StatusFound)
} else {
w.WriteHeader(http.StatusOK)
}
}

View file

@ -3,6 +3,7 @@ package auth
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"net/http"
"time"
@ -33,18 +34,23 @@ type sessionClaims struct {
type sessionID string
var oauthRefreshTokens jsonstore.JSONStore[oauthRefreshToken]
var oauthRefreshTokens jsonstore.Typed[oauthRefreshToken]
var (
defaultRefreshTokenExpiry = 30 * 24 * time.Hour // 1 month
refreshBefore = 30 * time.Second
)
var (
errNoRefreshToken = errors.New("no refresh token")
ErrRefreshTokenFailure = errors.New("failed to refresh token")
)
const sessionTokenIssuer = "GoDoxy"
func init() {
if IsOIDCEnabled() {
oauthRefreshTokens = jsonstore.NewStore[oauthRefreshToken]("oauth_refresh_tokens")
oauthRefreshTokens = jsonstore.Store[oauthRefreshToken]("oauth_refresh_tokens")
}
}
@ -66,6 +72,9 @@ func newSession(username string, groups []string) Session {
}
}
// getOnceOAuthRefreshToken returns the refresh token for the given session.
//
// The token is removed from the store after retrieval.
func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
token, ok := oauthRefreshTokens.Load(string(claims.SessionID))
if !ok {
@ -82,15 +91,16 @@ func getOnceOAuthRefreshToken(claims *Session) (*oauthRefreshToken, bool) {
}
func storeOAuthRefreshToken(sessionID sessionID, username, token string) {
logging.Debug().Str("username", username).Msg("setting oauth refresh token")
oauthRefreshTokens.Store(string(sessionID), oauthRefreshToken{
Username: username,
RefreshToken: token,
Expiry: time.Now().Add(defaultRefreshTokenExpiry),
})
logging.Debug().Str("username", username).Msg("stored oauth refresh token")
}
func invalidateOAuthRefreshToken(sessionID sessionID) {
logging.Debug().Str("session_id", string(sessionID)).Msg("invalidating oauth refresh token")
oauthRefreshTokens.Delete(string(sessionID))
}
@ -125,26 +135,20 @@ func (auth *OIDCProvider) parseSessionJWT(sessionJWT string) (claims *sessionCla
return claims, sessionToken.Valid && claims.Issuer == sessionTokenIssuer, nil
}
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request) error {
// check for session token
sessionCookie, err := r.Cookie(CookieOauthSessionToken)
if err != nil {
return ErrMissingToken
}
func (auth *OIDCProvider) TryRefreshToken(w http.ResponseWriter, r *http.Request, sessionJWT string) error {
// verify the session cookie
claims, valid, err := auth.parseSessionJWT(sessionCookie.Value)
claims, valid, err := auth.parseSessionJWT(sessionJWT)
if err != nil {
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
return fmt.Errorf("%w: %w", ErrInvalidSessionToken, err)
}
if !valid {
return ErrInvalidToken
return ErrInvalidSessionToken
}
// check if refresh is possible
refreshToken, ok := getOnceOAuthRefreshToken(&claims.Session)
if !ok {
return ErrMissingToken
return errNoRefreshToken
}
if !auth.checkAllowed(claims.Username, claims.Groups) {

View file

@ -39,20 +39,17 @@ type (
const (
CookieOauthState = "godoxy_oidc_state"
CookieOauthSessionID = "godoxy_session_id"
CookieOauthToken = "godoxy_token"
CookieOauthToken = "godoxy_oauth_token"
CookieOauthSessionToken = "godoxy_session_token"
)
const (
OIDCAuthCallbackPath = "/auth/callback"
OIDCAuthInitPath = "/auth/init"
OIDCPostAuthPath = "/auth/postauth"
OIDCLogoutPath = "/auth/logout"
)
var (
ErrMissingIDToken = errors.New("missing id_token")
ErrRefreshTokenFailure = errors.New("failed to refresh token")
)
var errMissingIDToken = errors.New("missing id_token field from oauth token")
// generateState generates a random string for OIDC state.
const oidcStateLength = 32
@ -118,14 +115,17 @@ func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
auth.allowedGroups = groups
}
// optRedirectPostAuth returns an oauth2 option that sets the "redirect_uri"
// parameter of the authorization URL to the post auth path of the current
// request host.
func optRedirectPostAuth(r *http.Request) oauth2.AuthCodeOption {
return oauth2.SetAuthURLParam("redirect_uri", "https://"+r.Host+OIDCPostAuthPath)
return oauth2.SetAuthURLParam("redirect_uri", "https://"+requestHost(r)+OIDCPostAuthPath)
}
func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Token) (string, *oidc.IDToken, error) {
idTokenJWT, ok := oauthToken.Extra("id_token").(string)
if !ok {
return "", nil, ErrMissingIDToken
return "", nil, errMissingIDToken
}
idToken, err := auth.oidcVerifier.Verify(ctx, idTokenJWT)
if err != nil {
@ -135,34 +135,37 @@ func (auth *OIDCProvider) getIdToken(ctx context.Context, oauthToken *oauth2.Tok
}
func (auth *OIDCProvider) HandleAuth(w http.ResponseWriter, r *http.Request) {
// check for session token
_, err := r.Cookie(CookieOauthSessionToken)
if err == nil {
err := auth.TryRefreshToken(w, r)
if err != nil {
logging.Debug().Err(err).Msg("failed to refresh token")
auth.LogoutHandler(w, r)
} else {
http.Redirect(w, r, "/", http.StatusFound)
}
return
}
switch r.URL.Path {
case OIDCAuthCallbackPath:
state := generateState()
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
// redirect user to Idp
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
case OIDCAuthInitPath:
auth.LoginHandler(w, r)
case OIDCPostAuthPath:
auth.PostAuthCallbackHandler(w, r)
case OIDCLogoutPath:
auth.LogoutHandler(w, r)
default:
http.Redirect(w, r, OIDCAuthCallbackPath, http.StatusFound)
http.Redirect(w, r, OIDCAuthInitPath, http.StatusFound)
}
}
func (auth *OIDCProvider) LoginHandler(w http.ResponseWriter, r *http.Request) {
// check for session token
sessionToken, err := r.Cookie(CookieOauthSessionToken)
if err == nil {
err = auth.TryRefreshToken(w, r, sessionToken.Value)
if err != nil {
logging.Debug().Err(err).Msg("failed to refresh token")
auth.clearCookie(w, r)
}
http.Redirect(w, r, "/", http.StatusFound)
return
}
state := generateState()
setTokenCookie(w, r, CookieOauthState, state, 300*time.Second)
// redirect user to Idp
http.Redirect(w, r, auth.oauthConfig.AuthCodeURL(state, optRedirectPostAuth(r)), http.StatusFound)
}
func parseClaims(idToken *oidc.IDToken) (*IDTokenClaims, error) {
var claim IDTokenClaims
if err := idToken.Claims(&claim); err != nil {
@ -188,17 +191,17 @@ func (auth *OIDCProvider) checkAllowed(user string, groups []string) bool {
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
tokenCookie, err := r.Cookie(CookieOauthToken)
if err != nil {
return ErrMissingToken
return ErrMissingOAuthToken
}
idToken, err := auth.oidcVerifier.Verify(r.Context(), tokenCookie.Value)
if err != nil {
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err)
}
claims, err := parseClaims(idToken)
if err != nil {
return fmt.Errorf("%w: %w", ErrInvalidToken, err)
return fmt.Errorf("%w: %w", ErrInvalidOAuthToken, err)
}
if !auth.checkAllowed(claims.Username, claims.Groups) {
@ -270,6 +273,7 @@ func (auth *OIDCProvider) LogoutHandler(w http.ResponseWriter, r *http.Request)
if auth.endSessionURL != nil && oauthToken != nil {
query := auth.endSessionURL.Query()
query.Set("id_token_hint", oauthToken.Value)
query.Set("post_logout_redirect_uri", "https://"+requestHost(r))
clone := *auth.endSessionURL
clone.RawQuery = query.Encode()

View file

@ -8,6 +8,7 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
@ -35,6 +36,7 @@ func setupMockOIDC(t *testing.T) {
},
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
},
endSessionURL: Must(url.Parse("http://mock-provider/logout")),
oidcProvider: provider,
oidcVerifier: provider.Verifier(&oidc.Config{
ClientID: "test-client",
@ -148,14 +150,14 @@ func TestOIDCLoginHandler(t *testing.T) {
}{
{
name: "Success - Redirects to provider",
wantStatus: http.StatusTemporaryRedirect,
wantStatus: http.StatusFound,
wantRedirect: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest(http.MethodGet, "/", nil)
req := httptest.NewRequest(http.MethodGet, OIDCAuthInitPath, nil)
w := httptest.NewRecorder()
defaultAuth.(*OIDCProvider).HandleAuth(w, req)
@ -194,7 +196,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
state: "valid-state",
code: "valid-code",
setupMocks: true,
wantStatus: http.StatusTemporaryRedirect,
wantStatus: http.StatusFound,
},
{
name: "Failure - Missing state",
@ -396,7 +398,7 @@ func TestCheckToken(t *testing.T) {
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrInvalidToken,
wantErr: ErrInvalidOAuthToken,
},
{
name: "Error - Server returns incorrect audience",
@ -407,7 +409,7 @@ func TestCheckToken(t *testing.T) {
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrInvalidToken,
wantErr: ErrInvalidOAuthToken,
},
{
name: "Error - Server returns expired token",
@ -418,7 +420,7 @@ func TestCheckToken(t *testing.T) {
"preferred_username": "user1",
"groups": []string{"group1"},
},
wantErr: ErrInvalidToken,
wantErr: ErrInvalidOAuthToken,
},
}
for _, tc := range tests {
@ -448,3 +450,35 @@ func TestCheckToken(t *testing.T) {
})
}
}
func TestLogoutHandler(t *testing.T) {
t.Helper()
setupMockOIDC(t)
req := httptest.NewRequest(http.MethodGet, OIDCLogoutPath, nil)
w := httptest.NewRecorder()
req.AddCookie(&http.Cookie{
Name: CookieOauthToken,
Value: "test-token",
})
req.AddCookie(&http.Cookie{
Name: CookieOauthSessionToken,
Value: "test-session-token",
})
defaultAuth.(*OIDCProvider).LogoutHandler(w, req)
if got := w.Code; got != http.StatusFound {
t.Errorf("LogoutHandler() status = %v, want %v", got, http.StatusFound)
}
if got := w.Header().Get("Location"); got == "" {
t.Error("LogoutHandler() missing redirect location")
}
if len(w.Header().Values("Set-Cookie")) != 2 {
t.Error("LogoutHandler() did not clear all cookies")
}
}

View file

@ -4,4 +4,6 @@ import "net/http"
type Provider interface {
CheckToken(r *http.Request) error
LoginHandler(w http.ResponseWriter, r *http.Request)
LogoutHandler(w http.ResponseWriter, r *http.Request)
}

View file

@ -76,7 +76,7 @@ func (auth *UserPassAuth) NewToken() (token string, err error) {
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
jwtCookie, err := r.Cookie(auth.TokenCookieName())
if err != nil {
return ErrMissingToken
return ErrMissingSessionToken
}
var claims UserPassClaims
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
@ -90,7 +90,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
}
switch {
case !token.Valid:
return ErrInvalidToken
return ErrInvalidSessionToken
case claims.Username != auth.username:
return ErrUserNotAllowed.Subject(claims.Username)
case claims.ExpiresAt.Before(time.Now()):
@ -100,11 +100,7 @@ func (auth *UserPassAuth) CheckToken(r *http.Request) error {
return nil
}
func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
}
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
func (auth *UserPassAuth) LoginHandler(w http.ResponseWriter, r *http.Request) {
var creds struct {
User string `json:"username"`
Pass string `json:"password"`
@ -127,9 +123,9 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
w.WriteHeader(http.StatusOK)
}
func (auth *UserPassAuth) LogoutCallbackHandler(w http.ResponseWriter, r *http.Request) {
func (auth *UserPassAuth) LogoutHandler(w http.ResponseWriter, r *http.Request) {
clearTokenCookie(w, r, auth.TokenCookieName())
auth.RedirectLoginPage(w, r)
http.Redirect(w, r, "/", http.StatusFound)
}
func (auth *UserPassAuth) validatePassword(user, pass string) error {

View file

@ -98,7 +98,7 @@ func TestUserPassLoginCallbackHandler(t *testing.T) {
Host: "app.example.com",
Body: io.NopCloser(bytes.NewReader(Must(json.Marshal(tt.creds)))),
}
auth.LoginCallbackHandler(w, req)
auth.LoginHandler(w, req)
if tt.wantErr {
ExpectEqual(t, w.Code, http.StatusUnauthorized)
} else {

View file

@ -1,7 +1,6 @@
package auth
import (
"net"
"net/http"
"time"
@ -11,35 +10,34 @@ import (
)
var (
ErrMissingToken = gperr.New("missing token")
ErrInvalidToken = gperr.New("invalid token")
ErrMissingOAuthToken = gperr.New("missing oauth token")
ErrMissingSessionToken = gperr.New("missing session token")
ErrInvalidOAuthToken = gperr.New("invalid oauth token")
ErrInvalidSessionToken = gperr.New("invalid session token")
ErrUserNotAllowed = gperr.New("user not allowed")
)
// cookieFQDN returns the fully qualified domain name of the request host
func requestHost(r *http.Request) string {
// check if it's from backend
switch r.Host {
case common.APIHTTPAddr:
// use XFH
return r.Header.Get("X-Forwarded-Host")
default:
return r.Host
}
}
// cookieDomain returns the fully qualified domain name of the request host
// with subdomain stripped.
//
// If the request host does not have a subdomain,
// an empty string is returned
//
// "abc.example.com" -> "example.com"
// "example.com" -> ""
func cookieFQDN(r *http.Request) string {
var host string
// check if it's from backend
switch r.Host {
case common.APIHTTPAddr:
// use XFH
host = r.Header.Get("X-Forwarded-Host")
default:
var err error
host, _, err = net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
}
parts := strutils.SplitRune(host, '.')
// "abc.example.com" -> ".example.com" (cross subdomain)
// "example.com" -> "" (same domain only)
func cookieDomain(r *http.Request) string {
parts := strutils.SplitRune(requestHost(r), '.')
if len(parts) < 2 {
return ""
}
@ -52,7 +50,7 @@ func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string,
Name: name,
Value: value,
MaxAge: int(ttl.Seconds()),
Domain: cookieFQDN(r),
Domain: cookieDomain(r),
HttpOnly: true,
Secure: common.APIJWTSecure,
SameSite: http.SameSiteLaxMode,
@ -65,7 +63,7 @@ func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
Name: name,
Value: "",
MaxAge: -1,
Domain: cookieFQDN(r),
Domain: cookieDomain(r),
HttpOnly: true,
Secure: common.APIJWTSecure,
SameSite: http.SameSiteLaxMode,

View file

@ -1,63 +0,0 @@
package jsonstore
import (
"encoding/json"
"path/filepath"
"sync"
"github.com/puzpuzpuz/xsync/v3"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
type jsonStoreInternal struct{ *xsync.MapOf[string, any] }
type namespace string
var stores = make(map[namespace]jsonStoreInternal)
var storesMu sync.Mutex
var storesPath = filepath.Join(common.DataDir, "data.json")
func Initialize() {
if err := load(); err != nil {
logging.Error().Err(err).Msg("failed to load stores")
}
task.OnProgramExit("save_stores", func() {
if err := save(); err != nil {
logging.Error().Err(err).Msg("failed to save stores")
}
})
}
func load() error {
storesMu.Lock()
defer storesMu.Unlock()
if err := utils.LoadJSONIfExist(storesPath, &stores); err != nil {
return err
}
return nil
}
func save() error {
storesMu.Lock()
defer storesMu.Unlock()
return utils.SaveJSON(storesPath, &stores, 0o644)
}
func (s jsonStoreInternal) MarshalJSON() ([]byte, error) {
return json.Marshal(xsync.ToPlainMapOf(s.MapOf))
}
func (s jsonStoreInternal) UnmarshalJSON(data []byte) error {
var tmp map[string]any
if err := json.Unmarshal(data, &tmp); err != nil {
return err
}
s.MapOf = xsync.NewMapOf[string, any](xsync.WithPresize(len(tmp)))
for k, v := range tmp {
s.MapOf.Store(k, v)
}
return nil
}

View file

@ -1,47 +1,95 @@
package jsonstore
import (
"encoding/json"
"path/filepath"
"sync"
"github.com/puzpuzpuz/xsync/v3"
"github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/gperr"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/task"
"github.com/yusing/go-proxy/internal/utils"
)
type JSONStore[VT any] struct{ m jsonStoreInternal }
type namespace string
func NewStore[VT any](namespace namespace) JSONStore[VT] {
storesMu.Lock()
defer storesMu.Unlock()
if s, ok := stores[namespace]; ok {
return JSONStore[VT]{s}
type Typed[VT any] struct {
*xsync.MapOf[string, VT]
}
type storesMap struct {
sync.RWMutex
m map[namespace]any
}
var stores = storesMap{m: make(map[namespace]any)}
var storesPath = common.DataDir
func init() {
if err := load(); err != nil {
logging.Error().Err(err).Msg("failed to load stores")
}
m := jsonStoreInternal{xsync.NewMapOf[string, any]()}
stores[namespace] = m
return JSONStore[VT]{m}
}
func (s JSONStore[VT]) Load(key string) (_ VT, _ bool) {
value, ok := s.m.Load(key)
if !ok {
return
task.OnProgramExit("save_stores", func() {
if err := save(); err != nil {
logging.Error().Err(err).Msg("failed to save stores")
}
return value.(VT), true
})
}
func (s JSONStore[VT]) Has(key string) bool {
_, ok := s.m.Load(key)
return ok
}
func (s JSONStore[VT]) Store(key string, value VT) {
s.m.Store(key, value)
}
func (s JSONStore[VT]) Delete(key string) {
s.m.Delete(key)
}
func (s JSONStore[VT]) Iter(yield func(key string, value VT) bool) {
for k, v := range s.m.Range {
if !yield(k, v.(VT)) {
return
func load() error {
stores.Lock()
defer stores.Unlock()
errs := gperr.NewBuilder("failed to load data stores")
for ns, store := range stores.m {
if err := utils.LoadJSONIfExist(filepath.Join(storesPath, string(ns)+".json"), &store); err != nil {
errs.Add(err)
}
}
return errs.Error()
}
func save() error {
stores.Lock()
defer stores.Unlock()
errs := gperr.NewBuilder("failed to save data stores")
for ns, store := range stores.m {
if err := utils.SaveJSON(filepath.Join(common.DataDir, string(ns)+".json"), &store, 0o644); err != nil {
errs.Add(err)
}
}
return errs.Error()
}
func Store[VT any](namespace namespace) Typed[VT] {
stores.Lock()
defer stores.Unlock()
if s, ok := stores.m[namespace]; ok {
return s.(Typed[VT])
}
m := Typed[VT]{MapOf: xsync.NewMapOf[string, VT]()}
stores.m[namespace] = m
return m
}
func (s Typed[VT]) MarshalJSON() ([]byte, error) {
tmp := make(map[string]VT, s.Size())
for k, v := range s.Range {
tmp[k] = v
}
return json.Marshal(tmp)
}
func (s Typed[VT]) UnmarshalJSON(data []byte) error {
tmp := make(map[string]VT)
if err := json.Unmarshal(data, &tmp); err != nil {
return err
}
s.MapOf = xsync.NewMapOf[string, VT](xsync.WithPresize(len(tmp)))
for k, v := range tmp {
s.MapOf.Store(k, v)
}
return nil
}

View file

@ -6,7 +6,7 @@ import (
)
func TestNewJSON(t *testing.T) {
store := NewStore[string]("test")
store := Store[string]("test")
store.Store("a", "1")
if v, _ := store.Load("a"); v != "1" {
t.Fatal("expected 1, got", v)
@ -16,16 +16,16 @@ func TestNewJSON(t *testing.T) {
func TestSaveLoad(t *testing.T) {
tmpDir := t.TempDir()
storesPath = filepath.Join(tmpDir, "data.json")
store := NewStore[string]("test")
store := Store[string]("test")
store.Store("a", "1")
if err := save(); err != nil {
t.Fatal(err)
}
stores = nil
stores.m = nil
if err := load(); err != nil {
t.Fatal(err)
}
store = NewStore[string]("test")
store = Store[string]("test")
if v, _ := store.Load("a"); v != "1" {
t.Fatal("expected 1, got", v)
}

View file

@ -72,18 +72,13 @@ func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proce
return false
}
if r.URL.Path == auth.OIDCLogoutPath {
amw.auth.LogoutHandler(w, r)
return false
}
err := amw.auth.CheckToken(r)
if err == nil {
return true
}
switch {
case errors.Is(err, auth.ErrMissingToken):
case errors.Is(err, auth.ErrMissingOAuthToken):
amw.auth.HandleAuth(w, r)
default:
auth.WriteBlockPage(w, http.StatusForbidden, err.Error(), auth.OIDCLogoutPath)