mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-04 02:42:34 +02:00
fix tests and callbackURL
This commit is contained in:
parent
c5e0ac6f38
commit
b44c8586cc
7 changed files with 152 additions and 27 deletions
|
@ -43,7 +43,7 @@ func IsOIDCEnabled() bool {
|
||||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
if IsEnabled() {
|
if IsEnabled() {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if err := defaultAuth.CheckToken(w, r); err != nil {
|
if err := defaultAuth.CheckToken(r); err != nil {
|
||||||
U.RespondError(w, err, http.StatusUnauthorized)
|
U.RespondError(w, err, http.StatusUnauthorized)
|
||||||
} else {
|
} else {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
|
|
|
@ -23,11 +23,16 @@ type OIDCProvider struct {
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oidcVerifier *oidc.IDTokenVerifier
|
oidcVerifier *oidc.IDTokenVerifier
|
||||||
allowedUsers []string
|
allowedUsers []string
|
||||||
overrideHost bool
|
isMiddleware bool
|
||||||
}
|
}
|
||||||
|
|
||||||
const CookieOauthState = "godoxy_oidc_state"
|
const CookieOauthState = "godoxy_oidc_state"
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCMiddlewareCallbackPath = "/auth/callback"
|
||||||
|
OIDCLogoutPath = "/auth/logout"
|
||||||
|
)
|
||||||
|
|
||||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
|
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
|
||||||
if len(allowedUsers) == 0 {
|
if len(allowedUsers) == 0 {
|
||||||
return nil, errors.New("OIDC allowed users must not be empty")
|
return nil, errors.New("OIDC allowed users must not be empty")
|
||||||
|
@ -69,15 +74,18 @@ func (auth *OIDCProvider) TokenCookieName() string {
|
||||||
return "godoxy_oidc_token"
|
return "godoxy_oidc_token"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
|
func (auth *OIDCProvider) SetIsMiddleware(enabled bool) {
|
||||||
auth.overrideHost = enabled
|
auth.isMiddleware = enabled
|
||||||
|
if auth.isMiddleware {
|
||||||
|
auth.oauthConfig.RedirectURL = OIDCMiddlewareCallbackPath
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||||
auth.allowedUsers = users
|
auth.allowedUsers = users
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error {
|
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
token, err := r.Cookie(auth.TokenCookieName())
|
token, err := r.Cookie(auth.TokenCookieName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrMissingToken
|
return ErrMissingToken
|
||||||
|
@ -137,7 +145,7 @@ func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Reque
|
||||||
})
|
})
|
||||||
|
|
||||||
redirURL := auth.oauthConfig.AuthCodeURL(state)
|
redirURL := auth.oauthConfig.AuthCodeURL(state)
|
||||||
if auth.overrideHost {
|
if auth.isMiddleware {
|
||||||
u, err := r.URL.Parse(redirURL)
|
u, err := r.URL.Parse(redirURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
|
@ -165,12 +173,13 @@ func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Query().Get("state") != state.Value {
|
query := r.URL.Query()
|
||||||
|
if query.Get("state") != state.Value {
|
||||||
U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest)
|
U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
code := query.Get("code")
|
||||||
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code)
|
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
||||||
|
|
|
@ -8,7 +8,10 @@ import (
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupMockOIDC configures mock OIDC provider for testing.
|
// setupMockOIDC configures mock OIDC provider for testing.
|
||||||
|
@ -130,10 +133,12 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.wantStatus == http.StatusTemporaryRedirect {
|
if tt.wantStatus == http.StatusTemporaryRedirect {
|
||||||
cookie := w.Header().Get("Set-Cookie")
|
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||||
if cookie == "" {
|
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
|
||||||
t.Error("OIDCCallbackHandler() missing token cookie")
|
ExpectTrue(t, setCookie.Value != "")
|
||||||
}
|
ExpectEqual(t, setCookie.Path, "/")
|
||||||
|
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||||
|
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
|
|
||||||
type Provider interface {
|
type Provider interface {
|
||||||
TokenCookieName() string
|
TokenCookieName() string
|
||||||
CheckToken(w http.ResponseWriter, r *http.Request) error
|
CheckToken(r *http.Request) error
|
||||||
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
|
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
|
||||||
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
|
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,7 +58,7 @@ func (auth *UserPassAuth) TokenCookieName() string {
|
||||||
return "godoxy_token"
|
return "godoxy_token"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) {
|
func (auth *UserPassAuth) NewToken() (token string, err error) {
|
||||||
claim := &UserPassClaims{
|
claim := &UserPassClaims{
|
||||||
Username: auth.username,
|
Username: auth.username,
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
@ -73,7 +73,7 @@ func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (t
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error {
|
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||||
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ErrMissingToken
|
return ErrMissingToken
|
||||||
|
@ -118,7 +118,7 @@ func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Re
|
||||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
token, err := auth.CreateToken(w, r)
|
token, err := auth.NewToken()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -132,7 +132,7 @@ func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
||||||
return ErrInvalidUsername.Subject(user)
|
return ErrInvalidUsername.Subject(user)
|
||||||
}
|
}
|
||||||
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
|
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
|
||||||
return ErrInvalidPassword.Subject(pass)
|
return ErrInvalidPassword.With(err).Subject(pass)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
116
internal/api/v1/auth/userpass_test.go
Normal file
116
internal/api/v1/auth/userpass_test.go
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newMockUserPassAuth() *UserPassAuth {
|
||||||
|
return &UserPassAuth{
|
||||||
|
username: "username",
|
||||||
|
pwdHash: E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
|
||||||
|
secret: []byte("abcdefghijklmnopqrstuvwxyz"),
|
||||||
|
tokenTTL: time.Hour,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassValidateCredentials(t *testing.T) {
|
||||||
|
auth := newMockUserPassAuth()
|
||||||
|
err := auth.validatePassword("username", "password")
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
err = auth.validatePassword("username", "wrong-password")
|
||||||
|
ExpectError(t, ErrInvalidPassword, err)
|
||||||
|
err = auth.validatePassword("wrong-username", "password")
|
||||||
|
ExpectError(t, ErrInvalidUsername, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassCheckToken(t *testing.T) {
|
||||||
|
auth := newMockUserPassAuth()
|
||||||
|
token, err := auth.NewToken()
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
tests := []struct {
|
||||||
|
token string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
token: token,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token: "invalid-token",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := &http.Request{Header: http.Header{}}
|
||||||
|
if tt.token != "" {
|
||||||
|
req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token)
|
||||||
|
}
|
||||||
|
err = auth.CheckToken(req)
|
||||||
|
if tt.wantErr {
|
||||||
|
ExpectTrue(t, err != nil)
|
||||||
|
} else {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassLoginCallbackHandler(t *testing.T) {
|
||||||
|
type cred struct {
|
||||||
|
User string `json:"username"`
|
||||||
|
Pass string `json:"password"`
|
||||||
|
}
|
||||||
|
auth := newMockUserPassAuth()
|
||||||
|
tests := []struct {
|
||||||
|
creds cred
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
creds: cred{
|
||||||
|
User: "username",
|
||||||
|
Pass: "password",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
creds: cred{
|
||||||
|
User: "username",
|
||||||
|
Pass: "wrong-password",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := &http.Request{
|
||||||
|
Host: "app.example.com",
|
||||||
|
Body: io.NopCloser(bytes.NewReader(E.Must(json.Marshal(tt.creds)))),
|
||||||
|
}
|
||||||
|
auth.LoginCallbackHandler(w, req)
|
||||||
|
if tt.wantErr {
|
||||||
|
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
||||||
|
} else {
|
||||||
|
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||||
|
ExpectTrue(t, setCookie.Name == auth.TokenCookieName())
|
||||||
|
ExpectTrue(t, setCookie.Value != "")
|
||||||
|
ExpectEqual(t, setCookie.Domain, "example.com")
|
||||||
|
ExpectEqual(t, setCookie.Path, "/")
|
||||||
|
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||||
|
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||||
|
ExpectEqual(t, w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,11 +17,6 @@ type oidcMiddleware struct {
|
||||||
|
|
||||||
var OIDC = NewMiddleware[oidcMiddleware]()
|
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||||
|
|
||||||
const (
|
|
||||||
OIDCMiddlewareCallbackPath = "/auth/callback"
|
|
||||||
OIDCLogoutPath = "/auth/logout"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (amw *oidcMiddleware) finalize() error {
|
func (amw *oidcMiddleware) finalize() error {
|
||||||
if !auth.IsOIDCEnabled() {
|
if !auth.IsOIDCEnabled() {
|
||||||
return E.New("OIDC not enabled but Auth middleware is used")
|
return E.New("OIDC not enabled but Auth middleware is used")
|
||||||
|
@ -31,14 +26,14 @@ func (amw *oidcMiddleware) finalize() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
authProvider.SetOverrideHostEnabled(true)
|
authProvider.SetIsMiddleware(true)
|
||||||
if len(amw.AllowedUsers) > 0 {
|
if len(amw.AllowedUsers) > 0 {
|
||||||
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||||
}
|
}
|
||||||
|
|
||||||
amw.authMux = http.NewServeMux()
|
amw.authMux = http.NewServeMux()
|
||||||
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||||
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
})
|
})
|
||||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||||
|
@ -48,11 +43,11 @@ func (amw *oidcMiddleware) finalize() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
if err := amw.auth.CheckToken(w, r); err != nil {
|
if err := amw.auth.CheckToken(r); err != nil {
|
||||||
amw.authMux.ServeHTTP(w, r)
|
amw.authMux.ServeHTTP(w, r)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
if r.URL.Path == OIDCLogoutPath {
|
if r.URL.Path == auth.OIDCLogoutPath {
|
||||||
amw.logoutHandler(w, r)
|
amw.logoutHandler(w, r)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue