mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-01 09:32:35 +02:00
fix tests and callbackURL
This commit is contained in:
parent
c5e0ac6f38
commit
b44c8586cc
7 changed files with 152 additions and 27 deletions
|
@ -43,7 +43,7 @@ func IsOIDCEnabled() bool {
|
|||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
if IsEnabled() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := defaultAuth.CheckToken(w, r); err != nil {
|
||||
if err := defaultAuth.CheckToken(r); err != nil {
|
||||
U.RespondError(w, err, http.StatusUnauthorized)
|
||||
} else {
|
||||
next(w, r)
|
||||
|
|
|
@ -23,11 +23,16 @@ type OIDCProvider struct {
|
|||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
allowedUsers []string
|
||||
overrideHost bool
|
||||
isMiddleware bool
|
||||
}
|
||||
|
||||
const CookieOauthState = "godoxy_oidc_state"
|
||||
|
||||
const (
|
||||
OIDCMiddlewareCallbackPath = "/auth/callback"
|
||||
OIDCLogoutPath = "/auth/logout"
|
||||
)
|
||||
|
||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
|
||||
if len(allowedUsers) == 0 {
|
||||
return nil, errors.New("OIDC allowed users must not be empty")
|
||||
|
@ -69,15 +74,18 @@ func (auth *OIDCProvider) TokenCookieName() string {
|
|||
return "godoxy_oidc_token"
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
|
||||
auth.overrideHost = enabled
|
||||
func (auth *OIDCProvider) SetIsMiddleware(enabled bool) {
|
||||
auth.isMiddleware = enabled
|
||||
if auth.isMiddleware {
|
||||
auth.oauthConfig.RedirectURL = OIDCMiddlewareCallbackPath
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||
auth.allowedUsers = users
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error {
|
||||
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||
token, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
|
@ -137,7 +145,7 @@ func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Reque
|
|||
})
|
||||
|
||||
redirURL := auth.oauthConfig.AuthCodeURL(state)
|
||||
if auth.overrideHost {
|
||||
if auth.isMiddleware {
|
||||
u, err := r.URL.Parse(redirURL)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
|
@ -165,12 +173,13 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
|
|||
return
|
||||
}
|
||||
|
||||
if r.URL.Query().Get("state") != state.Value {
|
||||
query := r.URL.Query()
|
||||
if query.Get("state") != state.Value {
|
||||
U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
code := query.Get("code")
|
||||
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
||||
|
|
|
@ -8,7 +8,10 @@ import (
|
|||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
)
|
||||
|
||||
// setupMockOIDC configures mock OIDC provider for testing.
|
||||
|
@ -130,10 +133,12 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
|||
}
|
||||
|
||||
if tt.wantStatus == http.StatusTemporaryRedirect {
|
||||
cookie := w.Header().Get("Set-Cookie")
|
||||
if cookie == "" {
|
||||
t.Error("OIDCCallbackHandler() missing token cookie")
|
||||
}
|
||||
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
|
||||
ExpectTrue(t, setCookie.Value != "")
|
||||
ExpectEqual(t, setCookie.Path, "/")
|
||||
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
|
||||
type Provider interface {
|
||||
TokenCookieName() string
|
||||
CheckToken(w http.ResponseWriter, r *http.Request) error
|
||||
CheckToken(r *http.Request) error
|
||||
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
|
||||
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ func (auth *UserPassAuth) TokenCookieName() string {
|
|||
return "godoxy_token"
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) {
|
||||
func (auth *UserPassAuth) NewToken() (token string, err error) {
|
||||
claim := &UserPassClaims{
|
||||
Username: auth.username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
|
@ -73,7 +73,7 @@ func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (t
|
|||
return token, nil
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error {
|
||||
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
|
@ -118,7 +118,7 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
|
|||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
token, err := auth.CreateToken(w, r)
|
||||
token, err := auth.NewToken()
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
|
@ -132,7 +132,7 @@ func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
|||
return ErrInvalidUsername.Subject(user)
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
|
||||
return ErrInvalidPassword.Subject(pass)
|
||||
return ErrInvalidPassword.With(err).Subject(pass)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
116
internal/api/v1/auth/userpass_test.go
Normal file
116
internal/api/v1/auth/userpass_test.go
Normal file
|
@ -0,0 +1,116 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
func newMockUserPassAuth() *UserPassAuth {
|
||||
return &UserPassAuth{
|
||||
username: "username",
|
||||
pwdHash: E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
|
||||
secret: []byte("abcdefghijklmnopqrstuvwxyz"),
|
||||
tokenTTL: time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPassValidateCredentials(t *testing.T) {
|
||||
auth := newMockUserPassAuth()
|
||||
err := auth.validatePassword("username", "password")
|
||||
ExpectNoError(t, err)
|
||||
err = auth.validatePassword("username", "wrong-password")
|
||||
ExpectError(t, ErrInvalidPassword, err)
|
||||
err = auth.validatePassword("wrong-username", "password")
|
||||
ExpectError(t, ErrInvalidUsername, err)
|
||||
}
|
||||
|
||||
func TestUserPassCheckToken(t *testing.T) {
|
||||
auth := newMockUserPassAuth()
|
||||
token, err := auth.NewToken()
|
||||
ExpectNoError(t, err)
|
||||
tests := []struct {
|
||||
token string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
token: token,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
token: "invalid-token",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
token: "",
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
req := &http.Request{Header: http.Header{}}
|
||||
if tt.token != "" {
|
||||
req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token)
|
||||
}
|
||||
err = auth.CheckToken(req)
|
||||
if tt.wantErr {
|
||||
ExpectTrue(t, err != nil)
|
||||
} else {
|
||||
ExpectNoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserPassLoginCallbackHandler(t *testing.T) {
|
||||
type cred struct {
|
||||
User string `json:"username"`
|
||||
Pass string `json:"password"`
|
||||
}
|
||||
auth := newMockUserPassAuth()
|
||||
tests := []struct {
|
||||
creds cred
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
creds: cred{
|
||||
User: "username",
|
||||
Pass: "password",
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
creds: cred{
|
||||
User: "username",
|
||||
Pass: "wrong-password",
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
w := httptest.NewRecorder()
|
||||
req := &http.Request{
|
||||
Host: "app.example.com",
|
||||
Body: io.NopCloser(bytes.NewReader(E.Must(json.Marshal(tt.creds)))),
|
||||
}
|
||||
auth.LoginCallbackHandler(w, req)
|
||||
if tt.wantErr {
|
||||
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
||||
} else {
|
||||
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||
ExpectTrue(t, setCookie.Name == auth.TokenCookieName())
|
||||
ExpectTrue(t, setCookie.Value != "")
|
||||
ExpectEqual(t, setCookie.Domain, "example.com")
|
||||
ExpectEqual(t, setCookie.Path, "/")
|
||||
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||
ExpectEqual(t, w.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -17,11 +17,6 @@ type oidcMiddleware struct {
|
|||
|
||||
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||
|
||||
const (
|
||||
OIDCMiddlewareCallbackPath = "/auth/callback"
|
||||
OIDCLogoutPath = "/auth/logout"
|
||||
)
|
||||
|
||||
func (amw *oidcMiddleware) finalize() error {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return E.New("OIDC not enabled but Auth middleware is used")
|
||||
|
@ -31,14 +26,14 @@ func (amw *oidcMiddleware) finalize() error {
|
|||
return err
|
||||
}
|
||||
|
||||
authProvider.SetOverrideHostEnabled(true)
|
||||
authProvider.SetIsMiddleware(true)
|
||||
if len(amw.AllowedUsers) > 0 {
|
||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||
}
|
||||
|
||||
amw.authMux = http.NewServeMux()
|
||||
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
})
|
||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||
|
@ -48,11 +43,11 @@ func (amw *oidcMiddleware) finalize() error {
|
|||
}
|
||||
|
||||
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if err := amw.auth.CheckToken(w, r); err != nil {
|
||||
if err := amw.auth.CheckToken(r); err != nil {
|
||||
amw.authMux.ServeHTTP(w, r)
|
||||
return false
|
||||
}
|
||||
if r.URL.Path == OIDCLogoutPath {
|
||||
if r.URL.Path == auth.OIDCLogoutPath {
|
||||
amw.logoutHandler(w, r)
|
||||
return false
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue