From b44c8586cce7165e53db406b8f320811986fbd60 Mon Sep 17 00:00:00 2001 From: yusing Date: Tue, 14 Jan 2025 05:29:13 +0800 Subject: [PATCH] fix tests and callbackURL --- internal/api/v1/auth/auth.go | 2 +- internal/api/v1/auth/oidc.go | 23 +++-- internal/api/v1/auth/oidc_test.go | 13 ++- internal/api/v1/auth/provider.go | 2 +- internal/api/v1/auth/userpass.go | 8 +- internal/api/v1/auth/userpass_test.go | 116 ++++++++++++++++++++++++++ internal/net/http/middleware/oidc.go | 15 ++-- 7 files changed, 152 insertions(+), 27 deletions(-) create mode 100644 internal/api/v1/auth/userpass_test.go diff --git a/internal/api/v1/auth/auth.go b/internal/api/v1/auth/auth.go index 1a32bab..ed0060e 100644 --- a/internal/api/v1/auth/auth.go +++ b/internal/api/v1/auth/auth.go @@ -43,7 +43,7 @@ func IsOIDCEnabled() bool { func RequireAuth(next http.HandlerFunc) http.HandlerFunc { if IsEnabled() { return func(w http.ResponseWriter, r *http.Request) { - if err := defaultAuth.CheckToken(w, r); err != nil { + if err := defaultAuth.CheckToken(r); err != nil { U.RespondError(w, err, http.StatusUnauthorized) } else { next(w, r) diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index f6694d1..f61c445 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -23,11 +23,16 @@ type OIDCProvider struct { oidcProvider *oidc.Provider oidcVerifier *oidc.IDTokenVerifier allowedUsers []string - overrideHost bool + isMiddleware bool } const CookieOauthState = "godoxy_oidc_state" +const ( + OIDCMiddlewareCallbackPath = "/auth/callback" + OIDCLogoutPath = "/auth/logout" +) + func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) { if len(allowedUsers) == 0 { return nil, errors.New("OIDC allowed users must not be empty") @@ -69,15 +74,18 @@ func (auth *OIDCProvider) TokenCookieName() string { return "godoxy_oidc_token" } -func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) { - auth.overrideHost = enabled +func (auth *OIDCProvider) SetIsMiddleware(enabled bool) { + auth.isMiddleware = enabled + if auth.isMiddleware { + auth.oauthConfig.RedirectURL = OIDCMiddlewareCallbackPath + } } func (auth *OIDCProvider) SetAllowedUsers(users []string) { auth.allowedUsers = users } -func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error { +func (auth *OIDCProvider) CheckToken(r *http.Request) error { token, err := r.Cookie(auth.TokenCookieName()) if err != nil { return ErrMissingToken @@ -137,7 +145,7 @@ func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Reque }) redirURL := auth.oauthConfig.AuthCodeURL(state) - if auth.overrideHost { + if auth.isMiddleware { u, err := r.URL.Parse(redirURL) if err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) @@ -165,12 +173,13 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re return } - if r.URL.Query().Get("state") != state.Value { + query := r.URL.Query() + if query.Get("state") != state.Value { U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest) return } - code := r.URL.Query().Get("code") + code := query.Get("code") oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError) diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index cd29acd..c83579e 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -8,7 +8,10 @@ import ( "github.com/coreos/go-oidc/v3/oidc" "github.com/yusing/go-proxy/internal/common" + E "github.com/yusing/go-proxy/internal/error" "golang.org/x/oauth2" + + . "github.com/yusing/go-proxy/internal/utils/testing" ) // setupMockOIDC configures mock OIDC provider for testing. @@ -130,10 +133,12 @@ func TestOIDCCallbackHandler(t *testing.T) { } if tt.wantStatus == http.StatusTemporaryRedirect { - cookie := w.Header().Get("Set-Cookie") - if cookie == "" { - t.Error("OIDCCallbackHandler() missing token cookie") - } + setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) + ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName()) + ExpectTrue(t, setCookie.Value != "") + ExpectEqual(t, setCookie.Path, "/") + ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) + ExpectEqual(t, setCookie.HttpOnly, true) } }) } diff --git a/internal/api/v1/auth/provider.go b/internal/api/v1/auth/provider.go index 69f24b9..8ea4d32 100644 --- a/internal/api/v1/auth/provider.go +++ b/internal/api/v1/auth/provider.go @@ -6,7 +6,7 @@ import ( type Provider interface { TokenCookieName() string - CheckToken(w http.ResponseWriter, r *http.Request) error + CheckToken(r *http.Request) error RedirectLoginPage(w http.ResponseWriter, r *http.Request) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) } diff --git a/internal/api/v1/auth/userpass.go b/internal/api/v1/auth/userpass.go index 05b2887..ae80c1c 100644 --- a/internal/api/v1/auth/userpass.go +++ b/internal/api/v1/auth/userpass.go @@ -58,7 +58,7 @@ func (auth *UserPassAuth) TokenCookieName() string { return "godoxy_token" } -func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) { +func (auth *UserPassAuth) NewToken() (token string, err error) { claim := &UserPassClaims{ Username: auth.username, RegisteredClaims: jwt.RegisteredClaims{ @@ -73,7 +73,7 @@ func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (t return token, nil } -func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error { +func (auth *UserPassAuth) CheckToken(r *http.Request) error { jwtCookie, err := r.Cookie(auth.TokenCookieName()) if err != nil { return ErrMissingToken @@ -118,7 +118,7 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re U.HandleErr(w, r, err, http.StatusUnauthorized) return } - token, err := auth.CreateToken(w, r) + token, err := auth.NewToken() if err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return @@ -132,7 +132,7 @@ func (auth *UserPassAuth) validatePassword(user, pass string) error { return ErrInvalidUsername.Subject(user) } if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil { - return ErrInvalidPassword.Subject(pass) + return ErrInvalidPassword.With(err).Subject(pass) } return nil } diff --git a/internal/api/v1/auth/userpass_test.go b/internal/api/v1/auth/userpass_test.go new file mode 100644 index 0000000..c43360e --- /dev/null +++ b/internal/api/v1/auth/userpass_test.go @@ -0,0 +1,116 @@ +package auth + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + E "github.com/yusing/go-proxy/internal/error" + . "github.com/yusing/go-proxy/internal/utils/testing" + "golang.org/x/crypto/bcrypt" +) + +func newMockUserPassAuth() *UserPassAuth { + return &UserPassAuth{ + username: "username", + pwdHash: E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)), + secret: []byte("abcdefghijklmnopqrstuvwxyz"), + tokenTTL: time.Hour, + } +} + +func TestUserPassValidateCredentials(t *testing.T) { + auth := newMockUserPassAuth() + err := auth.validatePassword("username", "password") + ExpectNoError(t, err) + err = auth.validatePassword("username", "wrong-password") + ExpectError(t, ErrInvalidPassword, err) + err = auth.validatePassword("wrong-username", "password") + ExpectError(t, ErrInvalidUsername, err) +} + +func TestUserPassCheckToken(t *testing.T) { + auth := newMockUserPassAuth() + token, err := auth.NewToken() + ExpectNoError(t, err) + tests := []struct { + token string + wantErr bool + }{ + { + token: token, + wantErr: false, + }, + { + token: "invalid-token", + wantErr: true, + }, + { + token: "", + wantErr: true, + }, + } + for _, tt := range tests { + req := &http.Request{Header: http.Header{}} + if tt.token != "" { + req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token) + } + err = auth.CheckToken(req) + if tt.wantErr { + ExpectTrue(t, err != nil) + } else { + ExpectNoError(t, err) + } + } +} + +func TestUserPassLoginCallbackHandler(t *testing.T) { + type cred struct { + User string `json:"username"` + Pass string `json:"password"` + } + auth := newMockUserPassAuth() + tests := []struct { + creds cred + wantErr bool + }{ + { + creds: cred{ + User: "username", + Pass: "password", + }, + wantErr: false, + }, + { + creds: cred{ + User: "username", + Pass: "wrong-password", + }, + wantErr: true, + }, + } + for _, tt := range tests { + w := httptest.NewRecorder() + req := &http.Request{ + Host: "app.example.com", + Body: io.NopCloser(bytes.NewReader(E.Must(json.Marshal(tt.creds)))), + } + auth.LoginCallbackHandler(w, req) + if tt.wantErr { + ExpectEqual(t, w.Code, http.StatusUnauthorized) + } else { + setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie"))) + ExpectTrue(t, setCookie.Name == auth.TokenCookieName()) + ExpectTrue(t, setCookie.Value != "") + ExpectEqual(t, setCookie.Domain, "example.com") + ExpectEqual(t, setCookie.Path, "/") + ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode) + ExpectEqual(t, setCookie.HttpOnly, true) + ExpectEqual(t, w.Code, http.StatusOK) + } + } +} diff --git a/internal/net/http/middleware/oidc.go b/internal/net/http/middleware/oidc.go index 464a662..f732d81 100644 --- a/internal/net/http/middleware/oidc.go +++ b/internal/net/http/middleware/oidc.go @@ -17,11 +17,6 @@ type oidcMiddleware struct { var OIDC = NewMiddleware[oidcMiddleware]() -const ( - OIDCMiddlewareCallbackPath = "/auth/callback" - OIDCLogoutPath = "/auth/logout" -) - func (amw *oidcMiddleware) finalize() error { if !auth.IsOIDCEnabled() { return E.New("OIDC not enabled but Auth middleware is used") @@ -31,14 +26,14 @@ func (amw *oidcMiddleware) finalize() error { return err } - authProvider.SetOverrideHostEnabled(true) + authProvider.SetIsMiddleware(true) if len(amw.AllowedUsers) > 0 { authProvider.SetAllowedUsers(amw.AllowedUsers) } amw.authMux = http.NewServeMux() - amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) - amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { + amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) + amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { http.Error(w, "Unauthorized", http.StatusUnauthorized) }) amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage) @@ -48,11 +43,11 @@ func (amw *oidcMiddleware) finalize() error { } func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - if err := amw.auth.CheckToken(w, r); err != nil { + if err := amw.auth.CheckToken(r); err != nil { amw.authMux.ServeHTTP(w, r) return false } - if r.URL.Path == OIDCLogoutPath { + if r.URL.Path == auth.OIDCLogoutPath { amw.logoutHandler(w, r) return false }