fix tests and callbackURL

This commit is contained in:
yusing 2025-01-14 05:29:13 +08:00
parent c5e0ac6f38
commit b44c8586cc
7 changed files with 152 additions and 27 deletions

View file

@ -43,7 +43,7 @@ func IsOIDCEnabled() bool {
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if IsEnabled() {
return func(w http.ResponseWriter, r *http.Request) {
if err := defaultAuth.CheckToken(w, r); err != nil {
if err := defaultAuth.CheckToken(r); err != nil {
U.RespondError(w, err, http.StatusUnauthorized)
} else {
next(w, r)

View file

@ -23,11 +23,16 @@ type OIDCProvider struct {
oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier
allowedUsers []string
overrideHost bool
isMiddleware bool
}
const CookieOauthState = "godoxy_oidc_state"
const (
OIDCMiddlewareCallbackPath = "/auth/callback"
OIDCLogoutPath = "/auth/logout"
)
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
if len(allowedUsers) == 0 {
return nil, errors.New("OIDC allowed users must not be empty")
@ -69,15 +74,18 @@ func (auth *OIDCProvider) TokenCookieName() string {
return "godoxy_oidc_token"
}
func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
auth.overrideHost = enabled
func (auth *OIDCProvider) SetIsMiddleware(enabled bool) {
auth.isMiddleware = enabled
if auth.isMiddleware {
auth.oauthConfig.RedirectURL = OIDCMiddlewareCallbackPath
}
}
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
auth.allowedUsers = users
}
func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error {
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
token, err := r.Cookie(auth.TokenCookieName())
if err != nil {
return ErrMissingToken
@ -137,7 +145,7 @@ func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Reque
})
redirURL := auth.oauthConfig.AuthCodeURL(state)
if auth.overrideHost {
if auth.isMiddleware {
u, err := r.URL.Parse(redirURL)
if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
@ -165,12 +173,13 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
return
}
if r.URL.Query().Get("state") != state.Value {
query := r.URL.Query()
if query.Get("state") != state.Value {
U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest)
return
}
code := r.URL.Query().Get("code")
code := query.Get("code")
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code)
if err != nil {
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)

View file

@ -8,7 +8,10 @@ import (
"github.com/coreos/go-oidc/v3/oidc"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"golang.org/x/oauth2"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
// setupMockOIDC configures mock OIDC provider for testing.
@ -130,10 +133,12 @@ func TestOIDCCallbackHandler(t *testing.T) {
}
if tt.wantStatus == http.StatusTemporaryRedirect {
cookie := w.Header().Get("Set-Cookie")
if cookie == "" {
t.Error("OIDCCallbackHandler() missing token cookie")
}
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
ExpectTrue(t, setCookie.Value != "")
ExpectEqual(t, setCookie.Path, "/")
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
ExpectEqual(t, setCookie.HttpOnly, true)
}
})
}

View file

@ -6,7 +6,7 @@ import (
type Provider interface {
TokenCookieName() string
CheckToken(w http.ResponseWriter, r *http.Request) error
CheckToken(r *http.Request) error
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
}

View file

@ -58,7 +58,7 @@ func (auth *UserPassAuth) TokenCookieName() string {
return "godoxy_token"
}
func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) {
func (auth *UserPassAuth) NewToken() (token string, err error) {
claim := &UserPassClaims{
Username: auth.username,
RegisteredClaims: jwt.RegisteredClaims{
@ -73,7 +73,7 @@ func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (t
return token, nil
}
func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error {
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
jwtCookie, err := r.Cookie(auth.TokenCookieName())
if err != nil {
return ErrMissingToken
@ -118,7 +118,7 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
U.HandleErr(w, r, err, http.StatusUnauthorized)
return
}
token, err := auth.CreateToken(w, r)
token, err := auth.NewToken()
if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
@ -132,7 +132,7 @@ func (auth *UserPassAuth) validatePassword(user, pass string) error {
return ErrInvalidUsername.Subject(user)
}
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
return ErrInvalidPassword.Subject(pass)
return ErrInvalidPassword.With(err).Subject(pass)
}
return nil
}

View file

@ -0,0 +1,116 @@
package auth
import (
"bytes"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
E "github.com/yusing/go-proxy/internal/error"
. "github.com/yusing/go-proxy/internal/utils/testing"
"golang.org/x/crypto/bcrypt"
)
func newMockUserPassAuth() *UserPassAuth {
return &UserPassAuth{
username: "username",
pwdHash: E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
secret: []byte("abcdefghijklmnopqrstuvwxyz"),
tokenTTL: time.Hour,
}
}
func TestUserPassValidateCredentials(t *testing.T) {
auth := newMockUserPassAuth()
err := auth.validatePassword("username", "password")
ExpectNoError(t, err)
err = auth.validatePassword("username", "wrong-password")
ExpectError(t, ErrInvalidPassword, err)
err = auth.validatePassword("wrong-username", "password")
ExpectError(t, ErrInvalidUsername, err)
}
func TestUserPassCheckToken(t *testing.T) {
auth := newMockUserPassAuth()
token, err := auth.NewToken()
ExpectNoError(t, err)
tests := []struct {
token string
wantErr bool
}{
{
token: token,
wantErr: false,
},
{
token: "invalid-token",
wantErr: true,
},
{
token: "",
wantErr: true,
},
}
for _, tt := range tests {
req := &http.Request{Header: http.Header{}}
if tt.token != "" {
req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token)
}
err = auth.CheckToken(req)
if tt.wantErr {
ExpectTrue(t, err != nil)
} else {
ExpectNoError(t, err)
}
}
}
func TestUserPassLoginCallbackHandler(t *testing.T) {
type cred struct {
User string `json:"username"`
Pass string `json:"password"`
}
auth := newMockUserPassAuth()
tests := []struct {
creds cred
wantErr bool
}{
{
creds: cred{
User: "username",
Pass: "password",
},
wantErr: false,
},
{
creds: cred{
User: "username",
Pass: "wrong-password",
},
wantErr: true,
},
}
for _, tt := range tests {
w := httptest.NewRecorder()
req := &http.Request{
Host: "app.example.com",
Body: io.NopCloser(bytes.NewReader(E.Must(json.Marshal(tt.creds)))),
}
auth.LoginCallbackHandler(w, req)
if tt.wantErr {
ExpectEqual(t, w.Code, http.StatusUnauthorized)
} else {
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
ExpectTrue(t, setCookie.Name == auth.TokenCookieName())
ExpectTrue(t, setCookie.Value != "")
ExpectEqual(t, setCookie.Domain, "example.com")
ExpectEqual(t, setCookie.Path, "/")
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
ExpectEqual(t, setCookie.HttpOnly, true)
ExpectEqual(t, w.Code, http.StatusOK)
}
}
}

View file

@ -17,11 +17,6 @@ type oidcMiddleware struct {
var OIDC = NewMiddleware[oidcMiddleware]()
const (
OIDCMiddlewareCallbackPath = "/auth/callback"
OIDCLogoutPath = "/auth/logout"
)
func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() {
return E.New("OIDC not enabled but Auth middleware is used")
@ -31,14 +26,14 @@ func (amw *oidcMiddleware) finalize() error {
return err
}
authProvider.SetOverrideHostEnabled(true)
authProvider.SetIsMiddleware(true)
if len(amw.AllowedUsers) > 0 {
authProvider.SetAllowedUsers(amw.AllowedUsers)
}
amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
@ -48,11 +43,11 @@ func (amw *oidcMiddleware) finalize() error {
}
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if err := amw.auth.CheckToken(w, r); err != nil {
if err := amw.auth.CheckToken(r); err != nil {
amw.authMux.ServeHTTP(w, r)
return false
}
if r.URL.Path == OIDCLogoutPath {
if r.URL.Path == auth.OIDCLogoutPath {
amw.logoutHandler(w, r)
return false
}