mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-04 02:42:34 +02:00
implement OIDC middleware
This commit is contained in:
parent
2af2346e35
commit
bb0ee5d7a9
15 changed files with 321 additions and 110 deletions
|
@ -2,6 +2,7 @@
|
||||||
TZ=ETC/UTC
|
TZ=ETC/UTC
|
||||||
|
|
||||||
# generate secret with `openssl rand -base64 32`
|
# generate secret with `openssl rand -base64 32`
|
||||||
|
# used for both user password authentication and OIDC
|
||||||
GODOXY_API_JWT_SECRET=
|
GODOXY_API_JWT_SECRET=
|
||||||
|
|
||||||
# the JWT token time-to-live
|
# the JWT token time-to-live
|
||||||
|
@ -11,6 +12,7 @@ GODOXY_API_JWT_TOKEN_TTL=1h
|
||||||
# Important: If using OIDC authentication, the API_USER must match the username
|
# Important: If using OIDC authentication, the API_USER must match the username
|
||||||
# provided by the OIDC provider.
|
# provided by the OIDC provider.
|
||||||
GODOXY_API_USER=admin
|
GODOXY_API_USER=admin
|
||||||
|
# Password is not required for OIDC authentication
|
||||||
GODOXY_API_PASSWORD=password
|
GODOXY_API_PASSWORD=password
|
||||||
|
|
||||||
# OIDC Configuration (optional)
|
# OIDC Configuration (optional)
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -39,7 +39,7 @@ profile:
|
||||||
|
|
||||||
run: build
|
run: build
|
||||||
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy
|
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy
|
||||||
bin/godoxy
|
[ -f .env ] && godotenv -f .env bin/godoxy || bin/godoxy
|
||||||
|
|
||||||
mtrace:
|
mtrace:
|
||||||
bin/godoxy debug-ls-mtrace > mtrace.json
|
bin/godoxy debug-ls-mtrace > mtrace.json
|
||||||
|
|
10
cmd/main.go
10
cmd/main.go
|
@ -9,7 +9,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/yusing/go-proxy/internal"
|
"github.com/yusing/go-proxy/internal"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/query"
|
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
"github.com/yusing/go-proxy/internal/config"
|
"github.com/yusing/go-proxy/internal/config"
|
||||||
|
@ -112,15 +111,6 @@ func main() {
|
||||||
cfg.Start()
|
cfg.Start()
|
||||||
config.WatchChanges()
|
config.WatchChanges()
|
||||||
|
|
||||||
if !auth.IsEnabled() {
|
|
||||||
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
|
|
||||||
} else {
|
|
||||||
// Initialize authentication providers
|
|
||||||
if err := auth.Initialize(); err != nil {
|
|
||||||
logging.Fatal().Err(err).Msg("Failed to initialize authentication providers")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
sig := make(chan os.Signal, 1)
|
||||||
signal.Notify(sig, syscall.SIGINT)
|
signal.Notify(sig, syscall.SIGINT)
|
||||||
signal.Notify(sig, syscall.SIGTERM)
|
signal.Notify(sig, syscall.SIGTERM)
|
||||||
|
|
|
@ -23,8 +23,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
|
||||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||||
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
|
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
|
||||||
mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler)
|
mux.HandleFunc("GET", "/v1/auth/redirect", auth.APIAuthRedirectHandler)
|
||||||
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler)
|
mux.HandleFunc("GET", "/v1/auth/callback", auth.APIOIDCCallbackHandler)
|
||||||
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
||||||
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
||||||
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
|
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
|
||||||
|
|
|
@ -2,6 +2,7 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -23,29 +25,59 @@ type (
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// Initialize sets up authentication providers.
|
// init sets up authentication providers.
|
||||||
func Initialize() error {
|
func init() {
|
||||||
|
if !IsEnabled() {
|
||||||
|
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
|
||||||
|
return
|
||||||
|
}
|
||||||
// Initialize OIDC if configured.
|
// Initialize OIDC if configured.
|
||||||
if common.OIDCIssuerURL != "" {
|
if common.OIDCIssuerURL != "" {
|
||||||
return InitOIDC(
|
if err := initOIDC(
|
||||||
common.OIDCIssuerURL,
|
common.OIDCIssuerURL,
|
||||||
common.OIDCClientID,
|
common.OIDCClientID,
|
||||||
common.OIDCClientSecret,
|
common.OIDCClientSecret,
|
||||||
common.OIDCRedirectURL,
|
common.OIDCRedirectURL,
|
||||||
)
|
); err != nil {
|
||||||
|
logging.Fatal().Err(err).Msg("failed to initialize OIDC provider")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsEnabled() bool {
|
func IsEnabled() bool {
|
||||||
return common.APIJWTSecret != nil || common.OIDCIssuerURL != ""
|
return common.APIJWTSecret != nil || IsOIDCEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration.
|
func IsOIDCEnabled() bool {
|
||||||
func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
return common.OIDCIssuerURL != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// cookieFQDN returns the fully qualified domain name of the request host
|
||||||
|
// with subdomain stripped.
|
||||||
|
//
|
||||||
|
// If the request host does not have a subdomain,
|
||||||
|
// an empty string is returned
|
||||||
|
//
|
||||||
|
// "abc.example.com" -> "example.com"
|
||||||
|
// "example.com" -> ""
|
||||||
|
func cookieFQDN(r *http.Request) string {
|
||||||
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
|
if err != nil {
|
||||||
|
host = r.Host
|
||||||
|
}
|
||||||
|
parts := strutils.SplitRune(host, '.')
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts[0] = ""
|
||||||
|
return strutils.JoinRune(parts, '.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIAuthRedirectHandler handles API redirect to login page or OIDC login base on configuration.
|
||||||
|
func APIAuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
switch {
|
switch {
|
||||||
case oauthConfig != nil:
|
case apiOAuth != nil:
|
||||||
RedirectOIDC(w, r)
|
apiOAuth.RedirectOIDC(w, r)
|
||||||
return
|
return
|
||||||
case common.APIJWTSecret != nil:
|
case common.APIJWTSecret != nil:
|
||||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||||
|
@ -55,7 +87,7 @@ func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
|
func setAuthenticatedCookie(w http.ResponseWriter, r *http.Request, username string) error {
|
||||||
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
||||||
claim := &Claims{
|
claim := &Claims{
|
||||||
Username: username,
|
Username: username,
|
||||||
|
@ -72,9 +104,10 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
|
||||||
Name: CookieToken,
|
Name: CookieToken,
|
||||||
Value: tokenStr,
|
Value: tokenStr,
|
||||||
Expires: expiresAt,
|
Expires: expiresAt,
|
||||||
|
Domain: cookieFQDN(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteStrictMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
|
@ -84,20 +117,22 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
|
||||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: CookieToken,
|
Name: CookieToken,
|
||||||
Value: "",
|
MaxAge: -1,
|
||||||
Expires: time.Unix(0, 0),
|
Domain: cookieFQDN(r),
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
SameSite: http.SameSiteStrictMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
})
|
})
|
||||||
AuthRedirectHandler(w, r)
|
APIAuthRedirectHandler(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
if IsEnabled() {
|
if IsEnabled() {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if checkToken(w, r) {
|
if err := CheckToken(w, r); err != nil {
|
||||||
|
U.RespondError(w, err, http.StatusUnauthorized)
|
||||||
|
} else {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -105,11 +140,10 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
|
func CheckToken(w http.ResponseWriter, r *http.Request) error {
|
||||||
tokenCookie, err := r.Cookie(CookieToken)
|
tokenCookie, err := r.Cookie(CookieToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.RespondError(w, E.New("missing token"), http.StatusUnauthorized)
|
return E.New("missing token")
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
var claims Claims
|
var claims Claims
|
||||||
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
@ -118,22 +152,17 @@ func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
|
||||||
}
|
}
|
||||||
return common.APIJWTSecret, nil
|
return common.APIJWTSecret, nil
|
||||||
})
|
})
|
||||||
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
break
|
|
||||||
case !token.Valid:
|
|
||||||
err = E.New("invalid token")
|
|
||||||
case claims.Username != common.APIUser:
|
|
||||||
err = E.New("username mismatch").Subject(claims.Username)
|
|
||||||
case claims.ExpiresAt.Before(time.Now()):
|
|
||||||
err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.RespondError(w, err, http.StatusForbidden)
|
return err
|
||||||
return false
|
}
|
||||||
|
switch {
|
||||||
|
case !token.Valid:
|
||||||
|
return E.New("invalid token")
|
||||||
|
case claims.Username != common.APIUser:
|
||||||
|
return E.New("username mismatch").Subject(claims.Username)
|
||||||
|
case claims.ExpiresAt.Before(time.Now()):
|
||||||
|
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,42 +13,66 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type OIDCProvider struct {
|
||||||
oauthConfig *oauth2.Config
|
oauthConfig *oauth2.Config
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oidcVerifier *oidc.IDTokenVerifier
|
oidcVerifier *oidc.IDTokenVerifier
|
||||||
|
overrideHost bool
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
apiOAuth *OIDCProvider
|
||||||
|
APIOIDCCallbackHandler http.HandlerFunc
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitOIDC initializes the OIDC provider.
|
// initOIDC initializes the OIDC provider.
|
||||||
func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error {
|
func initOIDC(issuerURL, clientID, clientSecret, redirectURL string) (err error) {
|
||||||
if issuerURL == "" {
|
if issuerURL == "" {
|
||||||
return nil // OIDC not configured
|
return nil // OIDC not configured
|
||||||
}
|
}
|
||||||
|
|
||||||
|
apiOAuth, err = NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL)
|
||||||
|
APIOIDCCallbackHandler = apiOAuth.OIDCCallbackHandler
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OIDCProvider, error) {
|
||||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcProvider = provider
|
return &OIDCProvider{
|
||||||
oidcVerifier = provider.Verifier(&oidc.Config{
|
oauthConfig: &oauth2.Config{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
})
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
||||||
|
},
|
||||||
|
oidcProvider: provider,
|
||||||
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
oauthConfig = &oauth2.Config{
|
func NewOIDCProviderFromEnv(redirectURL string) (*OIDCProvider, error) {
|
||||||
ClientID: clientID,
|
return NewOIDCProvider(
|
||||||
ClientSecret: clientSecret,
|
common.OIDCIssuerURL,
|
||||||
RedirectURL: redirectURL,
|
common.OIDCClientID,
|
||||||
Endpoint: provider.Endpoint(),
|
common.OIDCClientSecret,
|
||||||
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
redirectURL,
|
||||||
}
|
)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
func (provider *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
|
||||||
|
provider.overrideHost = enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// RedirectOIDC initiates the OIDC login flow.
|
// RedirectOIDC initiates the OIDC login flow.
|
||||||
func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
|
func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Request) {
|
||||||
if oauthConfig == nil {
|
if provider == nil {
|
||||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -59,18 +83,29 @@ func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
|
||||||
Value: state,
|
Value: state,
|
||||||
MaxAge: 300,
|
MaxAge: 300,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteNoneMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
})
|
})
|
||||||
|
|
||||||
url := oauthConfig.AuthCodeURL(state)
|
redirURL := provider.oauthConfig.AuthCodeURL(state)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
if provider.overrideHost {
|
||||||
|
u, err := r.URL.Parse(redirURL)
|
||||||
|
if err != nil {
|
||||||
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("redirect_uri", "https://"+r.Host+q.Get("redirect_uri"))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
redirURL = u.String()
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, redirURL, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCCallbackHandler handles the OIDC callback.
|
// OIDCCallbackHandler handles the OIDC callback.
|
||||||
func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if oauthConfig == nil {
|
if provider == nil {
|
||||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -81,7 +116,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if oidcProvider == nil {
|
if provider.oidcProvider == nil {
|
||||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -98,7 +133,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
oauth2Token, err := oauthConfig.Exchange(r.Context(), code)
|
oauth2Token, err := provider.oauthConfig.Exchange(r.Context(), code)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -110,7 +145,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := oidcVerifier.Verify(r.Context(), rawIDToken)
|
idToken, err := provider.oidcVerifier.Verify(r.Context(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
|
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -125,7 +160,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := setAuthenticatedCookie(w, claims.Username); err != nil {
|
if err := setAuthenticatedCookie(w, r, claims.Username); err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -148,7 +183,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create test JWT token
|
// Create test JWT token
|
||||||
if err := setAuthenticatedCookie(w, "test-user"); err != nil {
|
if err := setAuthenticatedCookie(w, r, "test-user"); err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,22 +14,22 @@ import (
|
||||||
func setupMockOIDC(t *testing.T) {
|
func setupMockOIDC(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
oauthConfig = &oauth2.Config{
|
apiOAuth = &OIDCProvider{
|
||||||
ClientID: "test-client",
|
oauthConfig: &oauth2.Config{
|
||||||
ClientSecret: "test-secret",
|
ClientID: "test-client",
|
||||||
RedirectURL: "http://localhost/callback",
|
ClientSecret: "test-secret",
|
||||||
Endpoint: oauth2.Endpoint{
|
RedirectURL: "http://localhost/callback",
|
||||||
AuthURL: "http://mock-provider/auth",
|
Endpoint: oauth2.Endpoint{
|
||||||
TokenURL: "http://mock-provider/token",
|
AuthURL: "http://mock-provider/auth",
|
||||||
|
TokenURL: "http://mock-provider/token",
|
||||||
|
},
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
},
|
},
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cleanup() {
|
func cleanup() {
|
||||||
oauthConfig = nil
|
apiOAuth = nil
|
||||||
oidcProvider = nil
|
|
||||||
oidcVerifier = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOIDCLoginHandler(t *testing.T) {
|
func TestOIDCLoginHandler(t *testing.T) {
|
||||||
|
@ -65,13 +65,13 @@ func TestOIDCLoginHandler(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if !tt.configureOAuth {
|
if !tt.configureOAuth {
|
||||||
oauthConfig = nil
|
apiOAuth = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
|
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
RedirectOIDC(w, req)
|
apiOAuth.RedirectOIDC(w, req)
|
||||||
|
|
||||||
if got := w.Code; got != tt.wantStatus {
|
if got := w.Code; got != tt.wantStatus {
|
||||||
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
||||||
|
@ -140,7 +140,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.configureOAuth {
|
if !tt.configureOAuth {
|
||||||
oauthConfig = nil
|
apiOAuth = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
|
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
|
||||||
|
@ -152,7 +152,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
OIDCCallbackHandler(w, req)
|
apiOAuth.OIDCCallbackHandler(w, req)
|
||||||
|
|
||||||
if got := w.Code; got != tt.wantStatus {
|
if got := w.Code; got != tt.wantStatus {
|
||||||
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
|
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
|
||||||
|
@ -194,7 +194,7 @@ func TestInitOIDC(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Cleanup(cleanup)
|
t.Cleanup(cleanup)
|
||||||
err := InitOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL)
|
err := initOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,7 +37,7 @@ func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
|
if err := setAuthenticatedCookie(w, r, creds.Username); err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -89,7 +89,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Then scraper / scanners will know the subdomain is invalid.
|
// Then scraper / scanners will know the subdomain is invalid.
|
||||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||||
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
logger.Err(err).
|
||||||
|
Str("method", r.Method).
|
||||||
|
Str("url", r.URL.String()).
|
||||||
|
Str("remote", r.RemoteAddr).
|
||||||
|
Msg("request")
|
||||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||||
if ok {
|
if ok {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
@ -26,28 +27,50 @@ type (
|
||||||
name string
|
name string
|
||||||
construct ImplNewFunc
|
construct ImplNewFunc
|
||||||
impl any
|
impl any
|
||||||
|
// priority is only applied for ReverseProxy.
|
||||||
|
//
|
||||||
|
// Middleware compose follows the order of the slice
|
||||||
|
//
|
||||||
|
// Default is 10, 0 is the highest
|
||||||
|
priority int
|
||||||
}
|
}
|
||||||
|
ByPriority []*Middleware
|
||||||
|
|
||||||
RequestModifier interface {
|
RequestModifier interface {
|
||||||
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
}
|
}
|
||||||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||||
MiddlewareWithSetup interface{ setup() }
|
MiddlewareWithSetup interface{ setup() }
|
||||||
MiddlewareFinalizer interface{ finalize() }
|
MiddlewareFinalizer interface{ finalize() }
|
||||||
|
MiddlewareFinalizerWithError interface {
|
||||||
|
finalize() error
|
||||||
|
}
|
||||||
MiddlewareWithTracer interface {
|
MiddlewareWithTracer interface {
|
||||||
enableTrace()
|
enableTrace()
|
||||||
getTracer() *Tracer
|
getTracer() *Tracer
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DefaultPriority = 10
|
||||||
|
|
||||||
|
func (m ByPriority) Len() int { return len(m) }
|
||||||
|
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
||||||
|
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
|
||||||
|
|
||||||
func NewMiddleware[ImplType any]() *Middleware {
|
func NewMiddleware[ImplType any]() *Middleware {
|
||||||
// type check
|
// type check
|
||||||
switch any(new(ImplType)).(type) {
|
t := any(new(ImplType))
|
||||||
|
switch t.(type) {
|
||||||
case RequestModifier:
|
case RequestModifier:
|
||||||
case ResponseModifier:
|
case ResponseModifier:
|
||||||
default:
|
default:
|
||||||
panic("must implement RequestModifier or ResponseModifier")
|
panic("must implement RequestModifier or ResponseModifier")
|
||||||
}
|
}
|
||||||
|
_, hasFinializer := t.(MiddlewareFinalizer)
|
||||||
|
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
|
||||||
|
if hasFinializer && hasFinializerWithError {
|
||||||
|
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
|
||||||
|
}
|
||||||
return &Middleware{
|
return &Middleware{
|
||||||
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
||||||
construct: func() any { return new(ImplType) },
|
construct: func() any { return new(ImplType) },
|
||||||
|
@ -84,13 +107,29 @@ func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
||||||
if len(optsRaw) == 0 {
|
if len(optsRaw) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
priority, ok := optsRaw["priority"].(int)
|
||||||
|
if ok {
|
||||||
|
m.priority = priority
|
||||||
|
// remove priority for deserialization, restore later
|
||||||
|
delete(optsRaw, "priority")
|
||||||
|
defer func() {
|
||||||
|
optsRaw["priority"] = priority
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
m.priority = DefaultPriority
|
||||||
|
}
|
||||||
return utils.Deserialize(optsRaw, m.impl)
|
return utils.Deserialize(optsRaw, m.impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) finalize() {
|
func (m *Middleware) finalize() error {
|
||||||
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||||
finalizer.finalize()
|
finalizer.finalize()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
|
||||||
|
return finalizer.finalize()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
|
@ -105,7 +144,9 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
if err := mid.apply(optsRaw); err != nil {
|
if err := mid.apply(optsRaw); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mid.finalize()
|
if err := mid.finalize(); err != nil {
|
||||||
|
return nil, E.From(err)
|
||||||
|
}
|
||||||
return mid, nil
|
return mid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,8 +160,9 @@ func (m *Middleware) String() string {
|
||||||
|
|
||||||
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||||
return json.MarshalIndent(map[string]any{
|
return json.MarshalIndent(map[string]any{
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
"options": m.impl,
|
"options": m.impl,
|
||||||
|
"priority": m.priority,
|
||||||
}, "", " ")
|
}, "", " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,6 +235,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
|
||||||
}
|
}
|
||||||
|
|
||||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||||
|
sort.Sort(ByPriority(middlewares))
|
||||||
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||||
|
|
||||||
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||||
|
|
37
internal/net/http/middleware/middleware_test.go
Normal file
37
internal/net/http/middleware/middleware_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testPriority struct {
|
||||||
|
Value int `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var test = NewMiddleware[testPriority]()
|
||||||
|
|
||||||
|
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewarePriority(t *testing.T) {
|
||||||
|
priorities := []int{4, 7, 9, 0}
|
||||||
|
chain := make([]*Middleware, len(priorities))
|
||||||
|
for i, p := range priorities {
|
||||||
|
mid, err := test.New(OptionsRaw{
|
||||||
|
"priority": p,
|
||||||
|
"value": i,
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
chain[i] = mid
|
||||||
|
}
|
||||||
|
res, err := newMiddlewaresTest(chain, nil)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
|
||||||
|
}
|
|
@ -14,6 +14,8 @@ import (
|
||||||
var allMiddlewares = map[string]*Middleware{
|
var allMiddlewares = map[string]*Middleware{
|
||||||
"redirecthttp": RedirectHTTP,
|
"redirecthttp": RedirectHTTP,
|
||||||
|
|
||||||
|
"auth": OIDC,
|
||||||
|
|
||||||
"request": ModifyRequest,
|
"request": ModifyRequest,
|
||||||
"modifyrequest": ModifyRequest,
|
"modifyrequest": ModifyRequest,
|
||||||
"response": ModifyResponse,
|
"response": ModifyResponse,
|
||||||
|
|
51
internal/net/http/middleware/oidc.go
Normal file
51
internal/net/http/middleware/oidc.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
type oidcMiddleware struct {
|
||||||
|
oauth *auth.OIDCProvider
|
||||||
|
authMux *http.ServeMux
|
||||||
|
}
|
||||||
|
|
||||||
|
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCMiddlewareCallbackPath = "/godoxy-auth-oidc/callback"
|
||||||
|
OIDCLogoutPath = "/logout"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (amw *oidcMiddleware) finalize() error {
|
||||||
|
if !auth.IsOIDCEnabled() {
|
||||||
|
return E.New("OIDC not enabled but Auth middleware is used")
|
||||||
|
}
|
||||||
|
provider, err := auth.NewOIDCProviderFromEnv(OIDCMiddlewareCallbackPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
provider.SetOverrideHostEnabled(true)
|
||||||
|
amw.oauth = provider
|
||||||
|
amw.authMux = http.NewServeMux()
|
||||||
|
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, provider.OIDCCallbackHandler)
|
||||||
|
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
})
|
||||||
|
amw.authMux.HandleFunc("/", provider.RedirectOIDC)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
if err, _ := auth.CheckToken(w, r); err != nil {
|
||||||
|
amw.authMux.ServeHTTP(w, r)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.URL.Path == OIDCLogoutPath {
|
||||||
|
auth.LogoutHandler(w, r)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -127,6 +127,20 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
||||||
}
|
}
|
||||||
args.setDefaults()
|
args.setDefaults()
|
||||||
|
|
||||||
|
mid, setOptErr := middleware.New(args.middlewareOpt)
|
||||||
|
if setOptErr != nil {
|
||||||
|
return nil, setOptErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||||
|
if args == nil {
|
||||||
|
args = new(testArgs)
|
||||||
|
}
|
||||||
|
args.setDefaults()
|
||||||
|
|
||||||
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
||||||
for k, v := range args.headers {
|
for k, v := range args.headers {
|
||||||
req.Header[k] = v
|
req.Header[k] = v
|
||||||
|
@ -139,14 +153,8 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
||||||
rr.parent = http.DefaultTransport
|
rr.parent = http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
|
||||||
|
patchReverseProxy(rp, middlewares)
|
||||||
mid, setOptErr := middleware.New(args.middlewareOpt)
|
|
||||||
if setOptErr != nil {
|
|
||||||
return nil, setOptErr
|
|
||||||
}
|
|
||||||
|
|
||||||
patchReverseProxy(rp, []*Middleware{mid})
|
|
||||||
rp.ServeHTTP(w, req)
|
rp.ServeHTTP(w, req)
|
||||||
|
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
|
|
|
@ -73,6 +73,16 @@ GoDoxy v0.8.2 expected changes
|
||||||
* Connection #0 to host localhost left intact
|
* Connection #0 to host localhost left intact
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- **Thanks [polds](https://github.com/polds)**
|
||||||
|
Support WebUI authentication via OIDC by setting these environment variables:
|
||||||
|
- `GODOXY_API_USER`
|
||||||
|
- `GODOXY_API_JWT_SECRET`
|
||||||
|
- `GODOXY_OIDC_ISSUER_URL`
|
||||||
|
- `GODOXY_OIDC_CLIENT_ID`
|
||||||
|
- `GODOXY_OIDC_CLIENT_SECRET`
|
||||||
|
- `GODOXY_OIDC_REDIRECT_URL`
|
||||||
|
- `GODOXY_OIDC_SCOPES`
|
||||||
|
|
||||||
- Caddyfile like rules
|
- Caddyfile like rules
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|
Loading…
Add table
Reference in a new issue