mirror of
https://github.com/yusing/godoxy.git
synced 2025-06-01 09:32:35 +02:00
auth code cleanup
This commit is contained in:
parent
bb0ee5d7a9
commit
9aee310844
17 changed files with 394 additions and 350 deletions
15
.env.example
15
.env.example
|
@ -1,20 +1,15 @@
|
|||
# set timezone to get correct log timestamp
|
||||
TZ=ETC/UTC
|
||||
|
||||
# API/WebUI user password login credentials (optional)
|
||||
# These fields are not required for OIDC authentication
|
||||
GODOXY_API_USER=admin
|
||||
GODOXY_API_PASSWORD=password
|
||||
# generate secret with `openssl rand -base64 32`
|
||||
# used for both user password authentication and OIDC
|
||||
GODOXY_API_JWT_SECRET=
|
||||
|
||||
# the JWT token time-to-live
|
||||
GODOXY_API_JWT_TOKEN_TTL=1h
|
||||
|
||||
# API/WebUI login credentials
|
||||
# Important: If using OIDC authentication, the API_USER must match the username
|
||||
# provided by the OIDC provider.
|
||||
GODOXY_API_USER=admin
|
||||
# Password is not required for OIDC authentication
|
||||
GODOXY_API_PASSWORD=password
|
||||
|
||||
# OIDC Configuration (optional)
|
||||
# Uncomment and configure these values to enable OIDC authentication.
|
||||
# GODOXY_OIDC_ISSUER_URL=https://accounts.google.com
|
||||
|
@ -24,6 +19,8 @@ GODOXY_API_PASSWORD=password
|
|||
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
||||
# Comma-separated list of scopes
|
||||
# GODOXY_OIDC_SCOPES=openid, profile, email
|
||||
# Comma-separated list of allowed users
|
||||
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
|
||||
|
||||
# Proxy listening address
|
||||
GODOXY_HTTP_ADDR=:80
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/yusing/go-proxy/internal"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/query"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
"github.com/yusing/go-proxy/internal/config"
|
||||
|
@ -108,6 +109,10 @@ func main() {
|
|||
return
|
||||
}
|
||||
|
||||
if err := auth.Initialize(); err != nil {
|
||||
logging.Fatal().Err(err).Msg("failed to initialize authentication")
|
||||
}
|
||||
|
||||
cfg.Start()
|
||||
config.WatchChanges()
|
||||
|
||||
|
|
|
@ -1,43 +1,44 @@
|
|||
package api
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
config "github.com/yusing/go-proxy/internal/config/types"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type ServeMux struct{ *http.ServeMux }
|
||||
|
||||
func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) {
|
||||
mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler))
|
||||
func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) {
|
||||
for _, m := range strutils.CommaSeperatedList(methods) {
|
||||
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
|
||||
}
|
||||
}
|
||||
|
||||
func NewHandler(cfg config.ConfigInstance) http.Handler {
|
||||
mux := ServeMux{http.NewServeMux()}
|
||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
|
||||
mux.HandleFunc("GET", "/v1/auth/redirect", auth.APIAuthRedirectHandler)
|
||||
mux.HandleFunc("GET", "/v1/auth/callback", auth.APIOIDCCallbackHandler)
|
||||
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
||||
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
||||
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
|
||||
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
|
||||
mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
||||
mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
||||
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
||||
mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile)
|
||||
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
|
||||
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
|
||||
mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon))
|
||||
|
||||
defaultAuth := auth.GetDefaultAuth()
|
||||
if defaultAuth != nil {
|
||||
mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage)
|
||||
mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler)
|
||||
mux.HandleFunc("GET,POST", "/v1/auth/logout", auth.LogoutCallbackHandler(defaultAuth))
|
||||
}
|
||||
return mux
|
||||
}
|
||||
|
||||
|
@ -46,20 +47,3 @@ func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w
|
|||
handler(cfg, w, r)
|
||||
}
|
||||
}
|
||||
|
||||
// allow only requests to API server with localhost.
|
||||
func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
||||
if common.IsDebug {
|
||||
return f
|
||||
}
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
|
||||
LogWarn(r).Msgf("blocked API request from %s", host)
|
||||
http.Error(w, "forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
LogDebug(r).Interface("headers", r.Header).Msg("API request")
|
||||
f(w, r)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,47 +1,35 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/logging"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
type (
|
||||
Credentials struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
Claims struct {
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
)
|
||||
var defaultAuth Provider
|
||||
|
||||
// init sets up authentication providers.
|
||||
func init() {
|
||||
// Initialize sets up authentication providers.
|
||||
func Initialize() error {
|
||||
if !IsEnabled() {
|
||||
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
// Initialize OIDC if configured.
|
||||
if common.OIDCIssuerURL != "" {
|
||||
if err := initOIDC(
|
||||
common.OIDCIssuerURL,
|
||||
common.OIDCClientID,
|
||||
common.OIDCClientSecret,
|
||||
common.OIDCRedirectURL,
|
||||
); err != nil {
|
||||
logging.Fatal().Err(err).Msg("failed to initialize OIDC provider")
|
||||
}
|
||||
defaultAuth, err = NewOIDCProviderFromEnv()
|
||||
} else {
|
||||
defaultAuth, err = NewUserPassAuthFromEnv()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func GetDefaultAuth() Provider {
|
||||
return defaultAuth
|
||||
}
|
||||
|
||||
func IsEnabled() bool {
|
||||
|
@ -52,85 +40,10 @@ func IsOIDCEnabled() bool {
|
|||
return common.OIDCIssuerURL != ""
|
||||
}
|
||||
|
||||
// cookieFQDN returns the fully qualified domain name of the request host
|
||||
// with subdomain stripped.
|
||||
//
|
||||
// If the request host does not have a subdomain,
|
||||
// an empty string is returned
|
||||
//
|
||||
// "abc.example.com" -> "example.com"
|
||||
// "example.com" -> ""
|
||||
func cookieFQDN(r *http.Request) string {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
parts := strutils.SplitRune(host, '.')
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
parts[0] = ""
|
||||
return strutils.JoinRune(parts, '.')
|
||||
}
|
||||
|
||||
// APIAuthRedirectHandler handles API redirect to login page or OIDC login base on configuration.
|
||||
func APIAuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch {
|
||||
case apiOAuth != nil:
|
||||
apiOAuth.RedirectOIDC(w, r)
|
||||
return
|
||||
case common.APIJWTSecret != nil:
|
||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||
return
|
||||
default:
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
}
|
||||
|
||||
func setAuthenticatedCookie(w http.ResponseWriter, r *http.Request, username string) error {
|
||||
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
||||
claim := &Claims{
|
||||
Username: username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
||||
},
|
||||
}
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
|
||||
tokenStr, err := token.SignedString(common.APIJWTSecret)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CookieToken,
|
||||
Value: tokenStr,
|
||||
Expires: expiresAt,
|
||||
Domain: cookieFQDN(r),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogoutHandler clear authentication cookie and redirect to login page.
|
||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CookieToken,
|
||||
MaxAge: -1,
|
||||
Domain: cookieFQDN(r),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
APIAuthRedirectHandler(w, r)
|
||||
}
|
||||
|
||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
if IsEnabled() {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := CheckToken(w, r); err != nil {
|
||||
if err := defaultAuth.CheckToken(w, r); err != nil {
|
||||
U.RespondError(w, err, http.StatusUnauthorized)
|
||||
} else {
|
||||
next(w, r)
|
||||
|
@ -139,30 +52,3 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
|||
}
|
||||
return next
|
||||
}
|
||||
|
||||
func CheckToken(w http.ResponseWriter, r *http.Request) error {
|
||||
tokenCookie, err := r.Cookie(CookieToken)
|
||||
if err != nil {
|
||||
return E.New("missing token")
|
||||
}
|
||||
var claims Claims
|
||||
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return common.APIJWTSecret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case !token.Valid:
|
||||
return E.New("invalid token")
|
||||
case claims.Username != common.APIUser:
|
||||
return E.New("username mismatch").Subject(claims.Username)
|
||||
case claims.ExpiresAt.Before(time.Now()):
|
||||
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
package auth
|
||||
|
||||
const (
|
||||
CookieToken = "godoxy_token"
|
||||
CookieOauthState = "godoxy_oauth_state"
|
||||
)
|
|
@ -2,8 +2,13 @@ package auth
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
|
@ -17,26 +22,17 @@ type OIDCProvider struct {
|
|||
oauthConfig *oauth2.Config
|
||||
oidcProvider *oidc.Provider
|
||||
oidcVerifier *oidc.IDTokenVerifier
|
||||
allowedUsers []string
|
||||
overrideHost bool
|
||||
}
|
||||
|
||||
var (
|
||||
apiOAuth *OIDCProvider
|
||||
APIOIDCCallbackHandler http.HandlerFunc
|
||||
)
|
||||
const CookieOauthState = "godoxy_oidc_state"
|
||||
|
||||
// initOIDC initializes the OIDC provider.
|
||||
func initOIDC(issuerURL, clientID, clientSecret, redirectURL string) (err error) {
|
||||
if issuerURL == "" {
|
||||
return nil // OIDC not configured
|
||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
|
||||
if len(allowedUsers) == 0 {
|
||||
return nil, errors.New("OIDC allowed users must not be empty")
|
||||
}
|
||||
|
||||
apiOAuth, err = NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL)
|
||||
APIOIDCCallbackHandler = apiOAuth.OIDCCallbackHandler
|
||||
return
|
||||
}
|
||||
|
||||
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OIDCProvider, error) {
|
||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
||||
|
@ -54,30 +50,82 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OI
|
|||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: clientID,
|
||||
}),
|
||||
allowedUsers: allowedUsers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewOIDCProviderFromEnv(redirectURL string) (*OIDCProvider, error) {
|
||||
// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables.
|
||||
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
||||
return NewOIDCProvider(
|
||||
common.OIDCIssuerURL,
|
||||
common.OIDCClientID,
|
||||
common.OIDCClientSecret,
|
||||
redirectURL,
|
||||
common.OIDCRedirectURL,
|
||||
common.OIDCAllowedUsers,
|
||||
)
|
||||
}
|
||||
|
||||
func (provider *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
|
||||
provider.overrideHost = enabled
|
||||
func (auth *OIDCProvider) TokenCookieName() string {
|
||||
return "godoxy_oidc_token"
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
|
||||
auth.overrideHost = enabled
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||
auth.allowedUsers = users
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error {
|
||||
token, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
}
|
||||
|
||||
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to verify ID token: %w", err)
|
||||
}
|
||||
|
||||
if len(idToken.Audience) == 0 {
|
||||
return ErrInvalidToken
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
return fmt.Errorf("failed to parse claims: %w", err)
|
||||
}
|
||||
|
||||
if !slices.Contains(auth.allowedUsers, claims.Username) {
|
||||
return ErrUserNotAllowed.Subject(claims.Username)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// generateState generates a random string for ODIC state.
|
||||
const odicStateLength = 32
|
||||
|
||||
func generateState() (string, error) {
|
||||
b := make([]byte, odicStateLength)
|
||||
_, err := rand.Read(b)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b)[:odicStateLength], nil
|
||||
}
|
||||
|
||||
// RedirectOIDC initiates the OIDC login flow.
|
||||
func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Request) {
|
||||
if provider == nil {
|
||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||
func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := generateState()
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
state := common.GenerateRandomString(32)
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: CookieOauthState,
|
||||
Value: state,
|
||||
|
@ -88,8 +136,8 @@ func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Reques
|
|||
Path: "/",
|
||||
})
|
||||
|
||||
redirURL := provider.oauthConfig.AuthCodeURL(state)
|
||||
if provider.overrideHost {
|
||||
redirURL := auth.oauthConfig.AuthCodeURL(state)
|
||||
if auth.overrideHost {
|
||||
u, err := r.URL.Parse(redirURL)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
|
@ -104,20 +152,10 @@ func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Reques
|
|||
}
|
||||
|
||||
// OIDCCallbackHandler handles the OIDC callback.
|
||||
func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
if provider == nil {
|
||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||
return
|
||||
}
|
||||
|
||||
func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
// For testing purposes, skip provider verification
|
||||
if common.IsTest {
|
||||
handleTestCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if provider.oidcProvider == nil {
|
||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
||||
auth.handleTestCallback(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -133,7 +171,7 @@ func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http
|
|||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
oauth2Token, err := provider.oauthConfig.Exchange(r.Context(), code)
|
||||
oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
||||
return
|
||||
|
@ -145,32 +183,20 @@ func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http
|
|||
return
|
||||
}
|
||||
|
||||
idToken, err := provider.oidcVerifier.Verify(r.Context(), rawIDToken)
|
||||
idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var claims struct {
|
||||
Email string `json:"email"`
|
||||
Username string `json:"preferred_username"`
|
||||
}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
U.HandleErr(w, r, fmt.Errorf("failed to parse claims: %w", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if err := setAuthenticatedCookie(w, r, claims.Username); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry))
|
||||
|
||||
// Redirect to home page
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
// handleTestCallback handles OIDC callback in test environment.
|
||||
func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||
state, err := r.Cookie(CookieOauthState)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
|
||||
|
@ -183,10 +209,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
// Create test JWT token
|
||||
if err := setAuthenticatedCookie(w, r, "test-user"); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour)
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
@ -14,7 +15,8 @@ import (
|
|||
func setupMockOIDC(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
apiOAuth = &OIDCProvider{
|
||||
provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
|
||||
defaultAuth = &OIDCProvider{
|
||||
oauthConfig: &oauth2.Config{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
|
@ -25,53 +27,42 @@ func setupMockOIDC(t *testing.T) {
|
|||
},
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
},
|
||||
oidcProvider: provider,
|
||||
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||
ClientID: "test-client",
|
||||
}),
|
||||
allowedUsers: []string{"test-user"},
|
||||
}
|
||||
}
|
||||
|
||||
func cleanup() {
|
||||
apiOAuth = nil
|
||||
defaultAuth = nil
|
||||
}
|
||||
|
||||
func TestOIDCLoginHandler(t *testing.T) {
|
||||
// Setup
|
||||
common.APIJWTSecret = []byte("test-secret")
|
||||
common.IsTest = true
|
||||
t.Cleanup(func() {
|
||||
cleanup()
|
||||
common.IsTest = false
|
||||
})
|
||||
t.Cleanup(cleanup)
|
||||
setupMockOIDC(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
configureOAuth bool
|
||||
wantStatus int
|
||||
wantRedirect bool
|
||||
name string
|
||||
wantStatus int
|
||||
wantRedirect bool
|
||||
}{
|
||||
{
|
||||
name: "Success - Redirects to provider",
|
||||
configureOAuth: true,
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
wantRedirect: true,
|
||||
},
|
||||
{
|
||||
name: "Failure - OIDC not configured",
|
||||
configureOAuth: false,
|
||||
wantStatus: http.StatusNotImplemented,
|
||||
wantRedirect: false,
|
||||
name: "Success - Redirects to provider",
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
wantRedirect: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if !tt.configureOAuth {
|
||||
apiOAuth = nil
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
apiOAuth.RedirectOIDC(w, req)
|
||||
defaultAuth.RedirectLoginPage(w, req)
|
||||
|
||||
if got := w.Code; got != tt.wantStatus {
|
||||
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
||||
|
@ -94,65 +85,45 @@ func TestOIDCLoginHandler(t *testing.T) {
|
|||
func TestOIDCCallbackHandler(t *testing.T) {
|
||||
// Setup
|
||||
common.APIJWTSecret = []byte("test-secret")
|
||||
common.IsTest = true
|
||||
t.Cleanup(func() {
|
||||
cleanup()
|
||||
common.IsTest = false
|
||||
})
|
||||
t.Cleanup(cleanup)
|
||||
tests := []struct {
|
||||
name string
|
||||
configureOAuth bool
|
||||
state string
|
||||
code string
|
||||
setupMocks func()
|
||||
wantStatus int
|
||||
name string
|
||||
state string
|
||||
code string
|
||||
setupMocks bool
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "Success - Valid callback",
|
||||
configureOAuth: true,
|
||||
state: "valid-state",
|
||||
code: "valid-code",
|
||||
setupMocks: func() {
|
||||
setupMockOIDC(t)
|
||||
},
|
||||
name: "Success - Valid callback",
|
||||
state: "valid-state",
|
||||
code: "valid-code",
|
||||
setupMocks: true,
|
||||
wantStatus: http.StatusTemporaryRedirect,
|
||||
},
|
||||
{
|
||||
name: "Failure - OIDC not configured",
|
||||
configureOAuth: false,
|
||||
wantStatus: http.StatusNotImplemented,
|
||||
},
|
||||
{
|
||||
name: "Failure - Missing state",
|
||||
configureOAuth: true,
|
||||
code: "valid-code",
|
||||
setupMocks: func() {
|
||||
setupMockOIDC(t)
|
||||
},
|
||||
name: "Failure - Missing state",
|
||||
code: "valid-code",
|
||||
setupMocks: true,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if tt.setupMocks != nil {
|
||||
tt.setupMocks()
|
||||
}
|
||||
|
||||
if !tt.configureOAuth {
|
||||
apiOAuth = nil
|
||||
if tt.setupMocks {
|
||||
setupMockOIDC(t)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
|
||||
if tt.state != "" {
|
||||
req.AddCookie(&http.Cookie{
|
||||
Name: "oauth_state",
|
||||
Name: CookieOauthState,
|
||||
Value: tt.state,
|
||||
})
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
|
||||
apiOAuth.OIDCCallbackHandler(w, req)
|
||||
defaultAuth.LoginCallbackHandler(w, req)
|
||||
|
||||
if got := w.Code; got != tt.wantStatus {
|
||||
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
|
||||
|
@ -169,32 +140,48 @@ func TestOIDCCallbackHandler(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInitOIDC(t *testing.T) {
|
||||
common.IsTest = true
|
||||
t.Cleanup(func() {
|
||||
common.IsTest = false
|
||||
})
|
||||
tests := []struct {
|
||||
name string
|
||||
issuerURL string
|
||||
clientID string
|
||||
clientSecret string
|
||||
redirectURL string
|
||||
allowedUsers []string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "Success - Empty configuration",
|
||||
name: "Fail - Empty configuration",
|
||||
issuerURL: "",
|
||||
clientID: "",
|
||||
clientSecret: "",
|
||||
redirectURL: "",
|
||||
wantErr: false,
|
||||
allowedUsers: nil,
|
||||
wantErr: true,
|
||||
},
|
||||
// {
|
||||
// name: "Success - Valid configuration",
|
||||
// issuerURL: "https://example.com",
|
||||
// clientID: "client_id",
|
||||
// clientSecret: "client_secret",
|
||||
// redirectURL: "https://example.com/callback",
|
||||
// allowedUsers: []string{"user1", "user2"},
|
||||
// wantErr: false,
|
||||
// },
|
||||
{
|
||||
name: "Fail - No allowed users",
|
||||
issuerURL: "https://example.com",
|
||||
clientID: "client_id",
|
||||
clientSecret: "client_secret",
|
||||
redirectURL: "https://example.com/callback",
|
||||
allowedUsers: []string{},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Cleanup(cleanup)
|
||||
err := initOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL)
|
||||
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
|
12
internal/api/v1/auth/provider.go
Normal file
12
internal/api/v1/auth/provider.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Provider interface {
|
||||
TokenCookieName() string
|
||||
CheckToken(w http.ResponseWriter, r *http.Request) error
|
||||
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
|
||||
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
|
||||
}
|
|
@ -1,13 +1,17 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||
"github.com/yusing/go-proxy/internal/common"
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -15,31 +19,120 @@ var (
|
|||
ErrInvalidPassword = E.New("invalid password")
|
||||
)
|
||||
|
||||
func validatePassword(cred *Credentials) error {
|
||||
if cred.Username != common.APIUser {
|
||||
return ErrInvalidUsername.Subject(cred.Username)
|
||||
type (
|
||||
UserPassAuth struct {
|
||||
username string
|
||||
pwdHash []byte
|
||||
secret []byte
|
||||
tokenTTL time.Duration
|
||||
}
|
||||
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
|
||||
return ErrInvalidPassword.Subject(cred.Password)
|
||||
UserPassClaims struct {
|
||||
Username string `json:"username"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
)
|
||||
|
||||
func NewUserPassAuth(username, password string, secret []byte, tokenTTL time.Duration) (*UserPassAuth, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &UserPassAuth{
|
||||
username: username,
|
||||
pwdHash: hash,
|
||||
secret: secret,
|
||||
tokenTTL: tokenTTL,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewUserPassAuthFromEnv() (*UserPassAuth, error) {
|
||||
return NewUserPassAuth(
|
||||
common.APIUser,
|
||||
common.APIPassword,
|
||||
common.APIJWTSecret,
|
||||
common.APIJWTTokenTTL,
|
||||
)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) TokenCookieName() string {
|
||||
return "godoxy_token"
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) {
|
||||
claim := &UserPassClaims{
|
||||
Username: auth.username,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(time.Now().Add(auth.tokenTTL)),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
|
||||
token, err = tok.SignedString(auth.secret)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error {
|
||||
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||
if err != nil {
|
||||
return ErrMissingToken
|
||||
}
|
||||
var claims UserPassClaims
|
||||
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||
}
|
||||
return auth.secret, nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch {
|
||||
case !token.Valid:
|
||||
return ErrInvalidToken
|
||||
case claims.Username != auth.username:
|
||||
return ErrUserNotAllowed.Subject(claims.Username)
|
||||
case claims.ExpiresAt.Before(time.Now()):
|
||||
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserPassLoginHandler handles user login.
|
||||
func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds Credentials
|
||||
func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||
var creds struct {
|
||||
User string `json:"username"`
|
||||
Pass string `json:"password"`
|
||||
}
|
||||
err := json.NewDecoder(r.Body).Decode(&creds)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if err := validatePassword(&creds); err != nil {
|
||||
if err := auth.validatePassword(creds.User, creds.Pass); err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if err := setAuthenticatedCookie(w, r, creds.Username); err != nil {
|
||||
token, err := auth.CreateToken(w, r)
|
||||
if err != nil {
|
||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
||||
if user != auth.username {
|
||||
return ErrInvalidUsername.Subject(user)
|
||||
}
|
||||
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
|
||||
return ErrInvalidPassword.Subject(pass)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
70
internal/api/v1/auth/utils.go
Normal file
70
internal/api/v1/auth/utils.go
Normal file
|
@ -0,0 +1,70 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
E "github.com/yusing/go-proxy/internal/error"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrMissingToken = E.New("missing token")
|
||||
ErrInvalidToken = E.New("invalid token")
|
||||
ErrUserNotAllowed = E.New("user not allowed")
|
||||
)
|
||||
|
||||
// cookieFQDN returns the fully qualified domain name of the request host
|
||||
// with subdomain stripped.
|
||||
//
|
||||
// If the request host does not have a subdomain,
|
||||
// an empty string is returned
|
||||
//
|
||||
// "abc.example.com" -> "example.com"
|
||||
// "example.com" -> ""
|
||||
func cookieFQDN(r *http.Request) string {
|
||||
host, _, err := net.SplitHostPort(r.Host)
|
||||
if err != nil {
|
||||
host = r.Host
|
||||
}
|
||||
parts := strutils.SplitRune(host, '.')
|
||||
if len(parts) < 2 {
|
||||
return ""
|
||||
}
|
||||
parts[0] = ""
|
||||
return strutils.JoinRune(parts, '.')
|
||||
}
|
||||
|
||||
func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: value,
|
||||
MaxAge: int(ttl.Seconds()),
|
||||
Domain: cookieFQDN(r),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
MaxAge: -1,
|
||||
Domain: cookieFQDN(r),
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
Path: "/",
|
||||
})
|
||||
}
|
||||
|
||||
func LogoutCallbackHandler(auth Provider) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||
auth.RedirectLoginPage(w, r)
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
|||
return logging.WithLevel(level).
|
||||
Str("module", "api").
|
||||
Str("remote", r.RemoteAddr).
|
||||
Str("host", r.Host).
|
||||
Str("uri", r.Method+" "+r.RequestURI)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,18 +1,11 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func HashPassword(pwd string) []byte {
|
||||
h := sha512.New()
|
||||
h.Write([]byte(pwd))
|
||||
return h.Sum(nil)
|
||||
}
|
||||
|
||||
func decodeJWTKey(key string) []byte {
|
||||
if key == "" {
|
||||
return nil
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -40,10 +41,10 @@ var (
|
|||
MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http")
|
||||
PrometheusEnabled = MetricsHTTPURL != ""
|
||||
|
||||
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
|
||||
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
|
||||
APIUser = GetEnvString("API_USER", "admin")
|
||||
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password"))
|
||||
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
|
||||
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
|
||||
APIUser = GetEnvString("API_USER", "admin")
|
||||
APIPassword = GetEnvString("API_PASSWORD", "password")
|
||||
|
||||
// OIDC Configuration.
|
||||
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
||||
|
@ -51,6 +52,7 @@ var (
|
|||
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
||||
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
||||
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
||||
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
|
||||
)
|
||||
|
||||
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
|
||||
|
@ -102,3 +104,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
|
|||
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
|
||||
return GetEnv(key, defaultValue, time.ParseDuration)
|
||||
}
|
||||
|
||||
func GetCommaSepEnv(key string, defaultValue string) []string {
|
||||
return strutils.CommaSeperatedList(GetEnvString(key, defaultValue))
|
||||
}
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// GenerateRandomString generates a random string of specified length.
|
||||
func GenerateRandomString(length int) string {
|
||||
b := make([]byte, length)
|
||||
rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)[:length]
|
||||
}
|
|
@ -8,43 +8,47 @@ import (
|
|||
)
|
||||
|
||||
type oidcMiddleware struct {
|
||||
oauth *auth.OIDCProvider
|
||||
authMux *http.ServeMux
|
||||
AllowedUsers []string
|
||||
|
||||
auth auth.Provider
|
||||
authMux *http.ServeMux
|
||||
logoutHandler http.HandlerFunc
|
||||
}
|
||||
|
||||
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||
|
||||
const (
|
||||
OIDCMiddlewareCallbackPath = "/godoxy-auth-oidc/callback"
|
||||
OIDCLogoutPath = "/logout"
|
||||
OIDCMiddlewareCallbackPath = "/auth/callback"
|
||||
OIDCLogoutPath = "/auth/logout"
|
||||
)
|
||||
|
||||
func (amw *oidcMiddleware) finalize() error {
|
||||
if !auth.IsOIDCEnabled() {
|
||||
return E.New("OIDC not enabled but Auth middleware is used")
|
||||
}
|
||||
provider, err := auth.NewOIDCProviderFromEnv(OIDCMiddlewareCallbackPath)
|
||||
authProvider, err := auth.NewOIDCProviderFromEnv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
provider.SetOverrideHostEnabled(true)
|
||||
amw.oauth = provider
|
||||
authProvider.SetOverrideHostEnabled(true)
|
||||
amw.authMux = http.NewServeMux()
|
||||
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, provider.OIDCCallbackHandler)
|
||||
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
})
|
||||
amw.authMux.HandleFunc("/", provider.RedirectOIDC)
|
||||
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||
amw.logoutHandler = auth.LogoutCallbackHandler(authProvider)
|
||||
amw.auth = authProvider
|
||||
return nil
|
||||
}
|
||||
|
||||
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||
if err, _ := auth.CheckToken(w, r); err != nil {
|
||||
if err := amw.auth.CheckToken(w, r); err != nil {
|
||||
amw.authMux.ServeHTTP(w, r)
|
||||
return false
|
||||
}
|
||||
if r.URL.Path == OIDCLogoutPath {
|
||||
auth.LogoutHandler(w, r)
|
||||
amw.logoutHandler(w, r)
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
|
|
@ -10,6 +10,9 @@ import (
|
|||
// CommaSeperatedList returns a list of strings split by commas,
|
||||
// then trim spaces from each element.
|
||||
func CommaSeperatedList(s string) []string {
|
||||
if s == "" {
|
||||
return []string{}
|
||||
}
|
||||
res := SplitComma(s)
|
||||
for i, part := range res {
|
||||
res[i] = strings.TrimSpace(part)
|
||||
|
|
|
@ -75,13 +75,12 @@ GoDoxy v0.8.2 expected changes
|
|||
|
||||
- **Thanks [polds](https://github.com/polds)**
|
||||
Support WebUI authentication via OIDC by setting these environment variables:
|
||||
- `GODOXY_API_USER`
|
||||
- `GODOXY_API_JWT_SECRET`
|
||||
- `GODOXY_OIDC_ISSUER_URL`
|
||||
- `GODOXY_OIDC_CLIENT_ID`
|
||||
- `GODOXY_OIDC_CLIENT_SECRET`
|
||||
- `GODOXY_OIDC_REDIRECT_URL`
|
||||
- `GODOXY_OIDC_SCOPES`
|
||||
- `GODOXY_OIDC_SCOPES` _(optional)_
|
||||
- `GODOXY_OIDC_ALLOWED_USERS`
|
||||
|
||||
- Caddyfile like rules
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue