mirror of
https://github.com/yusing/godoxy.git
synced 2025-07-21 20:04:03 +02:00
Feat/OIDC middleware (#50)
* implement OIDC middleware * auth code cleanup * allow override allowed_user in middleware, fix typos * fix tests and callbackURL * update next release docs * fix OIDC middleware not working with Authentik * feat: add groups support for OIDC claims (#41) Allow users to specify allowed groups in the env and use it to inspect the claims. This performs a logical AND of users and groups (additive). * merge feat/oidc-middleware (#49) * api: enrich provider statistifcs * fix: docker monitor now uses container status * Feat/auto schemas (#48) * use auto generated schemas * go version bump and dependencies upgrade * clarify some error messages --------- Co-authored-by: yusing <yusing@6uo.me> * cleanup some loadbalancer code * api: cleanup websocket code * api: add /v1/health/ws for health bubbles on dashboard * feat: experimental memory logger and logs api for WebUI --------- Co-authored-by: yusing <yusing@6uo.me> --------- Co-authored-by: yusing <yusing@6uo.me> Co-authored-by: Peter Olds <peter@olds.co>
This commit is contained in:
parent
0fad7b3411
commit
fb0dc7dea0
26 changed files with 1168 additions and 368 deletions
30
.env.example
30
.env.example
|
@ -1,17 +1,14 @@
|
||||||
# set timezone to get correct log timestamp
|
# set timezone to get correct log timestamp
|
||||||
TZ=ETC/UTC
|
TZ=ETC/UTC
|
||||||
|
|
||||||
# generate secret with `openssl rand -base64 32`
|
# API/WebUI user password login credentials (optional)
|
||||||
GODOXY_API_JWT_SECRET=
|
# These fields are not required for OIDC authentication
|
||||||
|
|
||||||
# the JWT token time-to-live
|
|
||||||
GODOXY_API_JWT_TOKEN_TTL=1h
|
|
||||||
|
|
||||||
# API/WebUI login credentials
|
|
||||||
# Important: If using OIDC authentication, the API_USER must match the username
|
|
||||||
# provided by the OIDC provider.
|
|
||||||
GODOXY_API_USER=admin
|
GODOXY_API_USER=admin
|
||||||
GODOXY_API_PASSWORD=password
|
GODOXY_API_PASSWORD=password
|
||||||
|
# generate secret with `openssl rand -base64 32`
|
||||||
|
GODOXY_API_JWT_SECRET=
|
||||||
|
# the JWT token time-to-live
|
||||||
|
GODOXY_API_JWT_TOKEN_TTL=1h
|
||||||
|
|
||||||
# OIDC Configuration (optional)
|
# OIDC Configuration (optional)
|
||||||
# Uncomment and configure these values to enable OIDC authentication.
|
# Uncomment and configure these values to enable OIDC authentication.
|
||||||
|
@ -22,6 +19,21 @@ GODOXY_API_PASSWORD=password
|
||||||
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
|
||||||
# Comma-separated list of scopes
|
# Comma-separated list of scopes
|
||||||
# GODOXY_OIDC_SCOPES=openid, profile, email
|
# GODOXY_OIDC_SCOPES=openid, profile, email
|
||||||
|
#
|
||||||
|
# User definitions: Uncomment and configure these values to restrict access to specific users or groups.
|
||||||
|
# These two fields act as a logical AND operator. For example, given the following membership:
|
||||||
|
# user1, group1
|
||||||
|
# user2, group1
|
||||||
|
# user3, group2
|
||||||
|
# user1, group2
|
||||||
|
# You can allow access to user3 AND all users of group1 by providing:
|
||||||
|
# # GODOXY_OIDC_ALLOWED_USERS=user3
|
||||||
|
# # GODOXY_OIDC_ALLOWED_GROUPS=group1
|
||||||
|
#
|
||||||
|
# Comma-separated list of allowed users.
|
||||||
|
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
|
||||||
|
# Optional: Comma-separated list of allowed groups.
|
||||||
|
# GODOXY_OIDC_ALLOWED_GROUPS=group1,group2
|
||||||
|
|
||||||
# Proxy listening address
|
# Proxy listening address
|
||||||
GODOXY_HTTP_ADDR=:80
|
GODOXY_HTTP_ADDR=:80
|
||||||
|
|
2
Makefile
2
Makefile
|
@ -39,7 +39,7 @@ profile:
|
||||||
|
|
||||||
run: build
|
run: build
|
||||||
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy
|
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy
|
||||||
bin/godoxy
|
[ -f .env ] && godotenv -f .env bin/godoxy || bin/godoxy
|
||||||
|
|
||||||
mtrace:
|
mtrace:
|
||||||
bin/godoxy debug-ls-mtrace > mtrace.json
|
bin/godoxy debug-ls-mtrace > mtrace.json
|
||||||
|
|
13
cmd/main.go
13
cmd/main.go
|
@ -117,18 +117,13 @@ func main() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := auth.Initialize(); err != nil {
|
||||||
|
logging.Fatal().Err(err).Msg("failed to initialize authentication")
|
||||||
|
}
|
||||||
|
|
||||||
cfg.Start()
|
cfg.Start()
|
||||||
config.WatchChanges()
|
config.WatchChanges()
|
||||||
|
|
||||||
if !auth.IsEnabled() {
|
|
||||||
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
|
|
||||||
} else {
|
|
||||||
// Initialize authentication providers
|
|
||||||
if err := auth.Initialize(); err != nil {
|
|
||||||
logging.Fatal().Err(err).Msg("Failed to initialize authentication providers")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sig := make(chan os.Signal, 1)
|
sig := make(chan os.Signal, 1)
|
||||||
signal.Notify(sig, syscall.SIGINT)
|
signal.Notify(sig, syscall.SIGINT)
|
||||||
signal.Notify(sig, syscall.SIGTERM)
|
signal.Notify(sig, syscall.SIGTERM)
|
||||||
|
|
|
@ -1,45 +1,46 @@
|
||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
v1 "github.com/yusing/go-proxy/internal/api/v1"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||||
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
"github.com/yusing/go-proxy/internal/api/v1/favicon"
|
||||||
. "github.com/yusing/go-proxy/internal/api/v1/utils"
|
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
|
||||||
config "github.com/yusing/go-proxy/internal/config/types"
|
config "github.com/yusing/go-proxy/internal/config/types"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ServeMux struct{ *http.ServeMux }
|
type ServeMux struct{ *http.ServeMux }
|
||||||
|
|
||||||
func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) {
|
func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) {
|
||||||
mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler))
|
for _, m := range strutils.CommaSeperatedList(methods) {
|
||||||
|
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHandler(cfg config.ConfigInstance) http.Handler {
|
func NewHandler(cfg config.ConfigInstance) http.Handler {
|
||||||
mux := ServeMux{http.NewServeMux()}
|
mux := ServeMux{http.NewServeMux()}
|
||||||
mux.HandleFunc("GET", "/v1", v1.Index)
|
mux.HandleFunc("GET", "/v1", v1.Index)
|
||||||
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
|
||||||
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
|
|
||||||
mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler)
|
|
||||||
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler)
|
|
||||||
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
|
|
||||||
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
|
|
||||||
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
|
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
|
||||||
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
|
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||||
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||||
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
|
||||||
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
|
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
|
||||||
mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
||||||
mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
|
|
||||||
mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile)
|
mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile)
|
||||||
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
|
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
|
||||||
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
|
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
|
||||||
mux.HandleFunc("GET", "/v1/health/ws", useCfg(cfg, v1.HealthWS))
|
mux.HandleFunc("GET", "/v1/health/ws", useCfg(cfg, v1.HealthWS))
|
||||||
mux.HandleFunc("GET", "/v1/logs/ws", useCfg(cfg, v1.LogsWS()))
|
mux.HandleFunc("GET", "/v1/logs/ws", useCfg(cfg, v1.LogsWS()))
|
||||||
mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon))
|
mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon))
|
||||||
|
|
||||||
|
defaultAuth := auth.GetDefaultAuth()
|
||||||
|
if defaultAuth != nil {
|
||||||
|
mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage)
|
||||||
|
mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler)
|
||||||
|
mux.HandleFunc("GET,POST", "/v1/auth/logout", auth.LogoutCallbackHandler(defaultAuth))
|
||||||
|
}
|
||||||
return mux
|
return mux
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,20 +49,3 @@ func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w
|
||||||
handler(cfg, w, r)
|
handler(cfg, w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// allow only requests to API server with localhost.
|
|
||||||
func checkHost(f http.HandlerFunc) http.HandlerFunc {
|
|
||||||
if common.IsDebug {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
|
||||||
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
|
|
||||||
LogWarn(r).Msgf("blocked API request from %s", host)
|
|
||||||
http.Error(w, "forbidden", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
LogDebug(r).Interface("headers", r.Header).Msg("API request")
|
|
||||||
f(w, r)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,139 +1,54 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v5"
|
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
"github.com/yusing/go-proxy/internal/logging"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type (
|
var defaultAuth Provider
|
||||||
Credentials struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
Password string `json:"password"`
|
|
||||||
}
|
|
||||||
Claims struct {
|
|
||||||
Username string `json:"username"`
|
|
||||||
jwt.RegisteredClaims
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
// Initialize sets up authentication providers.
|
// Initialize sets up authentication providers.
|
||||||
func Initialize() error {
|
func Initialize() error {
|
||||||
|
if !IsEnabled() {
|
||||||
|
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
// Initialize OIDC if configured.
|
// Initialize OIDC if configured.
|
||||||
if common.OIDCIssuerURL != "" {
|
if common.OIDCIssuerURL != "" {
|
||||||
return InitOIDC(
|
defaultAuth, err = NewOIDCProviderFromEnv()
|
||||||
common.OIDCIssuerURL,
|
} else {
|
||||||
common.OIDCClientID,
|
defaultAuth, err = NewUserPassAuthFromEnv()
|
||||||
common.OIDCClientSecret,
|
|
||||||
common.OIDCRedirectURL,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetDefaultAuth() Provider {
|
||||||
|
return defaultAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
func IsEnabled() bool {
|
func IsEnabled() bool {
|
||||||
return common.APIJWTSecret != nil || common.OIDCIssuerURL != ""
|
return common.APIJWTSecret != nil || IsOIDCEnabled()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration.
|
func IsOIDCEnabled() bool {
|
||||||
func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
|
return common.OIDCIssuerURL != ""
|
||||||
switch {
|
|
||||||
case oauthConfig != nil:
|
|
||||||
RedirectOIDC(w, r)
|
|
||||||
return
|
|
||||||
case common.APIJWTSecret != nil:
|
|
||||||
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
|
|
||||||
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
|
|
||||||
claim := &Claims{
|
|
||||||
Username: username,
|
|
||||||
RegisteredClaims: jwt.RegisteredClaims{
|
|
||||||
ExpiresAt: jwt.NewNumericDate(expiresAt),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
|
|
||||||
tokenStr, err := token.SignedString(common.APIJWTSecret)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: CookieToken,
|
|
||||||
Value: tokenStr,
|
|
||||||
Expires: expiresAt,
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
Path: "/",
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// LogoutHandler clear authentication cookie and redirect to login page.
|
|
||||||
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: CookieToken,
|
|
||||||
Value: "",
|
|
||||||
Expires: time.Unix(0, 0),
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
Path: "/",
|
|
||||||
})
|
|
||||||
AuthRedirectHandler(w, r)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||||
if IsEnabled() {
|
if IsEnabled() {
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
if checkToken(w, r) {
|
if err := defaultAuth.CheckToken(r); err != nil {
|
||||||
|
U.RespondError(w, err, http.StatusUnauthorized)
|
||||||
|
} else {
|
||||||
next(w, r)
|
next(w, r)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return next
|
return next
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
|
|
||||||
tokenCookie, err := r.Cookie(CookieToken)
|
|
||||||
if err != nil {
|
|
||||||
U.RespondError(w, E.New("missing token"), http.StatusUnauthorized)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
var claims Claims
|
|
||||||
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
|
||||||
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
|
||||||
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
|
||||||
}
|
|
||||||
return common.APIJWTSecret, nil
|
|
||||||
})
|
|
||||||
|
|
||||||
switch {
|
|
||||||
case err != nil:
|
|
||||||
break
|
|
||||||
case !token.Valid:
|
|
||||||
err = E.New("invalid token")
|
|
||||||
case claims.Username != common.APIUser:
|
|
||||||
err = E.New("username mismatch").Subject(claims.Username)
|
|
||||||
case claims.ExpiresAt.Before(time.Now()):
|
|
||||||
err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
U.RespondError(w, err, http.StatusForbidden)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
const (
|
|
||||||
CookieToken = "godoxy_token"
|
|
||||||
CookieOauthState = "godoxy_oauth_state"
|
|
||||||
)
|
|
|
@ -2,87 +2,186 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"slices"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
CE "github.com/yusing/go-proxy/internal/utils"
|
||||||
"github.com/yusing/go-proxy/internal/utils/strutils"
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type OIDCProvider struct {
|
||||||
oauthConfig *oauth2.Config
|
oauthConfig *oauth2.Config
|
||||||
oidcProvider *oidc.Provider
|
oidcProvider *oidc.Provider
|
||||||
oidcVerifier *oidc.IDTokenVerifier
|
oidcVerifier *oidc.IDTokenVerifier
|
||||||
|
allowedUsers []string
|
||||||
|
allowedGroups []string
|
||||||
|
isMiddleware bool
|
||||||
|
}
|
||||||
|
|
||||||
|
const CookieOauthState = "godoxy_oidc_state"
|
||||||
|
|
||||||
|
const (
|
||||||
|
OIDCMiddlewareCallbackPath = "/auth/callback"
|
||||||
|
OIDCLogoutPath = "/auth/logout"
|
||||||
)
|
)
|
||||||
|
|
||||||
// InitOIDC initializes the OIDC provider.
|
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers, allowedGroups []string) (*OIDCProvider, error) {
|
||||||
func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error {
|
if len(allowedUsers)+len(allowedGroups) == 0 {
|
||||||
if issuerURL == "" {
|
return nil, errors.New("OIDC users, groups, or both must not be empty")
|
||||||
return nil // OIDC not configured
|
|
||||||
}
|
}
|
||||||
|
|
||||||
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
provider, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcProvider = provider
|
return &OIDCProvider{
|
||||||
oidcVerifier = provider.Verifier(&oidc.Config{
|
oauthConfig: &oauth2.Config{
|
||||||
ClientID: clientID,
|
ClientID: clientID,
|
||||||
})
|
ClientSecret: clientSecret,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
Endpoint: provider.Endpoint(),
|
||||||
|
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
||||||
|
},
|
||||||
|
oidcProvider: provider,
|
||||||
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: clientID,
|
||||||
|
}),
|
||||||
|
allowedUsers: allowedUsers,
|
||||||
|
allowedGroups: allowedGroups,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
oauthConfig = &oauth2.Config{
|
// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables.
|
||||||
ClientID: clientID,
|
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
|
||||||
ClientSecret: clientSecret,
|
return NewOIDCProvider(
|
||||||
RedirectURL: redirectURL,
|
common.OIDCIssuerURL,
|
||||||
Endpoint: provider.Endpoint(),
|
common.OIDCClientID,
|
||||||
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
|
common.OIDCClientSecret,
|
||||||
|
common.OIDCRedirectURL,
|
||||||
|
common.OIDCAllowedUsers,
|
||||||
|
common.OIDCAllowedGroups,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) TokenCookieName() string {
|
||||||
|
return "godoxy_oidc_token"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetIsMiddleware(enabled bool) {
|
||||||
|
auth.isMiddleware = enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
|
||||||
|
auth.allowedUsers = users
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) SetAllowedGroups(groups []string) {
|
||||||
|
auth.allowedGroups = groups
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) CheckToken(r *http.Request) error {
|
||||||
|
token, err := r.Cookie(auth.TokenCookieName())
|
||||||
|
if err != nil {
|
||||||
|
return ErrMissingToken
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// checks for Expiry, Audience == ClientID, Issuer, etc.
|
||||||
|
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to verify ID token: %w: %w", ErrInvalidToken, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(idToken.Audience) == 0 {
|
||||||
|
return ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims struct {
|
||||||
|
Email string `json:"email"`
|
||||||
|
Username string `json:"preferred_username"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
}
|
||||||
|
if err := idToken.Claims(&claims); err != nil {
|
||||||
|
return fmt.Errorf("failed to parse claims: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logical AND between allowed users and groups.
|
||||||
|
allowedUser := slices.Contains(auth.allowedUsers, claims.Username)
|
||||||
|
allowedGroup := len(CE.Intersect(claims.Groups, auth.allowedGroups)) > 0
|
||||||
|
if !allowedUser && !allowedGroup {
|
||||||
|
return ErrUserNotAllowed.Subject(claims.Username)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateState generates a random string for OIDC state.
|
||||||
|
const oidcStateLength = 32
|
||||||
|
|
||||||
|
func generateState() (string, error) {
|
||||||
|
b := make([]byte, oidcStateLength)
|
||||||
|
_, err := rand.Read(b)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b)[:oidcStateLength], nil
|
||||||
|
}
|
||||||
|
|
||||||
// RedirectOIDC initiates the OIDC login flow.
|
// RedirectOIDC initiates the OIDC login flow.
|
||||||
func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||||
if oauthConfig == nil {
|
state, err := generateState()
|
||||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
if err != nil {
|
||||||
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
state := common.GenerateRandomString(32)
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
http.SetCookie(w, &http.Cookie{
|
||||||
Name: CookieOauthState,
|
Name: CookieOauthState,
|
||||||
Value: state,
|
Value: state,
|
||||||
MaxAge: 300,
|
MaxAge: 300,
|
||||||
HttpOnly: true,
|
HttpOnly: true,
|
||||||
SameSite: http.SameSiteNoneMode,
|
SameSite: http.SameSiteLaxMode,
|
||||||
Secure: true,
|
Secure: true,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
})
|
})
|
||||||
|
|
||||||
url := oauthConfig.AuthCodeURL(state)
|
redirURL := auth.oauthConfig.AuthCodeURL(state)
|
||||||
http.Redirect(w, r, url, http.StatusTemporaryRedirect)
|
if auth.isMiddleware {
|
||||||
|
u, err := r.URL.Parse(redirURL)
|
||||||
|
if err != nil {
|
||||||
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q := u.Query()
|
||||||
|
q.Set("redirect_uri", "https://"+r.Host+OIDCMiddlewareCallbackPath+q.Get("redirect_uri"))
|
||||||
|
u.RawQuery = q.Encode()
|
||||||
|
redirURL = u.String()
|
||||||
|
}
|
||||||
|
http.Redirect(w, r, redirURL, http.StatusTemporaryRedirect)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *OIDCProvider) exchange(r *http.Request) (*oauth2.Token, error) {
|
||||||
|
if auth.isMiddleware {
|
||||||
|
cfg := *auth.oauthConfig
|
||||||
|
cfg.RedirectURL = "https://" + r.Host + OIDCMiddlewareCallbackPath
|
||||||
|
return cfg.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||||
|
}
|
||||||
|
return auth.oauthConfig.Exchange(r.Context(), r.URL.Query().Get("code"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDCCallbackHandler handles the OIDC callback.
|
// OIDCCallbackHandler handles the OIDC callback.
|
||||||
func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
if oauthConfig == nil {
|
|
||||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// For testing purposes, skip provider verification
|
// For testing purposes, skip provider verification
|
||||||
if common.IsTest {
|
if common.IsTest {
|
||||||
handleTestCallback(w, r)
|
auth.handleTestCallback(w, r)
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if oidcProvider == nil {
|
|
||||||
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -92,13 +191,13 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.URL.Query().Get("state") != state.Value {
|
query := r.URL.Query()
|
||||||
|
if query.Get("state") != state.Value {
|
||||||
U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest)
|
U.HandleErr(w, r, E.New("invalid oauth state"), http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
oauth2Token, err := auth.exchange(r)
|
||||||
oauth2Token, err := oauthConfig.Exchange(r.Context(), code)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
|
@ -110,32 +209,20 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := oidcVerifier.Verify(r.Context(), rawIDToken)
|
idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
|
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var claims struct {
|
setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry))
|
||||||
Email string `json:"email"`
|
|
||||||
Username string `json:"preferred_username"`
|
|
||||||
}
|
|
||||||
if err := idToken.Claims(&claims); err != nil {
|
|
||||||
U.HandleErr(w, r, fmt.Errorf("failed to parse claims: %w", err), http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := setAuthenticatedCookie(w, claims.Username); err != nil {
|
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Redirect to home page
|
// Redirect to home page
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleTestCallback handles OIDC callback in test environment.
|
// handleTestCallback handles OIDC callback in test environment.
|
||||||
func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
state, err := r.Cookie(CookieOauthState)
|
state, err := r.Cookie(CookieOauthState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
|
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
|
||||||
|
@ -148,10 +235,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create test JWT token
|
// Create test JWT token
|
||||||
if err := setAuthenticatedCookie(w, "test-user"); err != nil {
|
setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour)
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,77 +1,165 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/coreos/go-oidc/v3/oidc"
|
"github.com/coreos/go-oidc/v3/oidc"
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
// setupMockOIDC configures mock OIDC provider for testing.
|
// setupMockOIDC configures mock OIDC provider for testing.
|
||||||
func setupMockOIDC(t *testing.T) {
|
func setupMockOIDC(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
oauthConfig = &oauth2.Config{
|
provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
|
||||||
ClientID: "test-client",
|
defaultAuth = &OIDCProvider{
|
||||||
ClientSecret: "test-secret",
|
oauthConfig: &oauth2.Config{
|
||||||
RedirectURL: "http://localhost/callback",
|
ClientID: "test-client",
|
||||||
Endpoint: oauth2.Endpoint{
|
ClientSecret: "test-secret",
|
||||||
AuthURL: "http://mock-provider/auth",
|
RedirectURL: "http://localhost/callback",
|
||||||
TokenURL: "http://mock-provider/token",
|
Endpoint: oauth2.Endpoint{
|
||||||
|
AuthURL: "http://mock-provider/auth",
|
||||||
|
TokenURL: "http://mock-provider/token",
|
||||||
|
},
|
||||||
|
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||||
},
|
},
|
||||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
oidcProvider: provider,
|
||||||
|
oidcVerifier: provider.Verifier(&oidc.Config{
|
||||||
|
ClientID: "test-client",
|
||||||
|
}),
|
||||||
|
allowedUsers: []string{"test-user"},
|
||||||
|
allowedGroups: []string{"test-group1", "test-group2"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// discoveryDocument returns a mock OIDC discovery document.
|
||||||
|
func discoveryDocument(t *testing.T, server *httptest.Server) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
discovery := map[string]any{
|
||||||
|
"issuer": server.URL,
|
||||||
|
"authorization_endpoint": server.URL + "/auth",
|
||||||
|
"token_endpoint": server.URL + "/token",
|
||||||
|
}
|
||||||
|
|
||||||
|
return discovery
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
keyID = "test-key-id"
|
||||||
|
clientID = "test-client-id"
|
||||||
|
)
|
||||||
|
|
||||||
|
type provider struct {
|
||||||
|
ts *httptest.Server
|
||||||
|
key *rsa.PrivateKey
|
||||||
|
verifier *oidc.IDTokenVerifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j *provider) SignClaims(t *testing.T, claims jwt.Claims) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = keyID
|
||||||
|
signed, err := token.SignedString(j.key)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
return signed
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupProvider(t *testing.T) *provider {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Generate an RSA key pair for the test.
|
||||||
|
privKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
|
||||||
|
// Build the matching public JWK that will be served by the endpoint.
|
||||||
|
jwk := buildRSAJWK(t, &privKey.PublicKey, keyID)
|
||||||
|
|
||||||
|
// Start a test server that serves the JWKS endpoint.
|
||||||
|
ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/.well-known/jwks.json":
|
||||||
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"keys": []any{jwk},
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
http.NotFound(w, r)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
t.Cleanup(ts.Close)
|
||||||
|
|
||||||
|
// Create a test OIDCProvider.
|
||||||
|
providerCtx := oidc.ClientContext(context.Background(), ts.Client())
|
||||||
|
keySet := oidc.NewRemoteKeySet(providerCtx, ts.URL+"/.well-known/jwks.json")
|
||||||
|
|
||||||
|
return &provider{
|
||||||
|
ts: ts,
|
||||||
|
key: privKey,
|
||||||
|
verifier: oidc.NewVerifier(ts.URL, keySet, &oidc.Config{
|
||||||
|
ClientID: clientID, // matches audience in the token
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildRSAJWK is a helper to construct a minimal JWK for the JWKS endpoint
|
||||||
|
func buildRSAJWK(t *testing.T, pub *rsa.PublicKey, kid string) map[string]any {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
nBytes := pub.N.Bytes()
|
||||||
|
eBytes := []byte{0x01, 0x00, 0x01} // Usually 65537
|
||||||
|
|
||||||
|
return map[string]any{
|
||||||
|
"kty": "RSA",
|
||||||
|
"alg": "RS256",
|
||||||
|
"use": "sig",
|
||||||
|
"kid": kid,
|
||||||
|
"n": base64.RawURLEncoding.EncodeToString(nBytes),
|
||||||
|
"e": base64.RawURLEncoding.EncodeToString(eBytes),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func cleanup() {
|
func cleanup() {
|
||||||
oauthConfig = nil
|
defaultAuth = nil
|
||||||
oidcProvider = nil
|
|
||||||
oidcVerifier = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOIDCLoginHandler(t *testing.T) {
|
func TestOIDCLoginHandler(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
common.APIJWTSecret = []byte("test-secret")
|
common.APIJWTSecret = []byte("test-secret")
|
||||||
common.IsTest = true
|
t.Cleanup(cleanup)
|
||||||
t.Cleanup(func() {
|
|
||||||
cleanup()
|
|
||||||
common.IsTest = false
|
|
||||||
})
|
|
||||||
setupMockOIDC(t)
|
setupMockOIDC(t)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
configureOAuth bool
|
wantStatus int
|
||||||
wantStatus int
|
wantRedirect bool
|
||||||
wantRedirect bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Success - Redirects to provider",
|
name: "Success - Redirects to provider",
|
||||||
configureOAuth: true,
|
wantStatus: http.StatusTemporaryRedirect,
|
||||||
wantStatus: http.StatusTemporaryRedirect,
|
wantRedirect: true,
|
||||||
wantRedirect: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Failure - OIDC not configured",
|
|
||||||
configureOAuth: false,
|
|
||||||
wantStatus: http.StatusNotImplemented,
|
|
||||||
wantRedirect: false,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if !tt.configureOAuth {
|
|
||||||
oauthConfig = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
|
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
RedirectOIDC(w, req)
|
defaultAuth.RedirectLoginPage(w, req)
|
||||||
|
|
||||||
if got := w.Code; got != tt.wantStatus {
|
if got := w.Code; got != tt.wantStatus {
|
||||||
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
|
||||||
|
@ -94,110 +182,267 @@ func TestOIDCLoginHandler(t *testing.T) {
|
||||||
func TestOIDCCallbackHandler(t *testing.T) {
|
func TestOIDCCallbackHandler(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
common.APIJWTSecret = []byte("test-secret")
|
common.APIJWTSecret = []byte("test-secret")
|
||||||
common.IsTest = true
|
t.Cleanup(cleanup)
|
||||||
t.Cleanup(func() {
|
|
||||||
cleanup()
|
|
||||||
common.IsTest = false
|
|
||||||
})
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
configureOAuth bool
|
state string
|
||||||
state string
|
code string
|
||||||
code string
|
setupMocks bool
|
||||||
setupMocks func()
|
wantStatus int
|
||||||
wantStatus int
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Success - Valid callback",
|
name: "Success - Valid callback",
|
||||||
configureOAuth: true,
|
state: "valid-state",
|
||||||
state: "valid-state",
|
code: "valid-code",
|
||||||
code: "valid-code",
|
setupMocks: true,
|
||||||
setupMocks: func() {
|
|
||||||
setupMockOIDC(t)
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusTemporaryRedirect,
|
wantStatus: http.StatusTemporaryRedirect,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "Failure - OIDC not configured",
|
name: "Failure - Missing state",
|
||||||
configureOAuth: false,
|
code: "valid-code",
|
||||||
wantStatus: http.StatusNotImplemented,
|
setupMocks: true,
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "Failure - Missing state",
|
|
||||||
configureOAuth: true,
|
|
||||||
code: "valid-code",
|
|
||||||
setupMocks: func() {
|
|
||||||
setupMockOIDC(t)
|
|
||||||
},
|
|
||||||
wantStatus: http.StatusBadRequest,
|
wantStatus: http.StatusBadRequest,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if tt.setupMocks != nil {
|
if tt.setupMocks {
|
||||||
tt.setupMocks()
|
setupMockOIDC(t)
|
||||||
}
|
|
||||||
|
|
||||||
if !tt.configureOAuth {
|
|
||||||
oauthConfig = nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
|
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
|
||||||
if tt.state != "" {
|
if tt.state != "" {
|
||||||
req.AddCookie(&http.Cookie{
|
req.AddCookie(&http.Cookie{
|
||||||
Name: "oauth_state",
|
Name: CookieOauthState,
|
||||||
Value: tt.state,
|
Value: tt.state,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
OIDCCallbackHandler(w, req)
|
defaultAuth.LoginCallbackHandler(w, req)
|
||||||
|
|
||||||
if got := w.Code; got != tt.wantStatus {
|
if got := w.Code; got != tt.wantStatus {
|
||||||
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
|
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
|
||||||
}
|
}
|
||||||
|
|
||||||
if tt.wantStatus == http.StatusTemporaryRedirect {
|
if tt.wantStatus == http.StatusTemporaryRedirect {
|
||||||
cookie := w.Header().Get("Set-Cookie")
|
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||||
if cookie == "" {
|
ExpectEqual(t, setCookie.Name, defaultAuth.TokenCookieName())
|
||||||
t.Error("OIDCCallbackHandler() missing token cookie")
|
ExpectTrue(t, setCookie.Value != "")
|
||||||
}
|
ExpectEqual(t, setCookie.Path, "/")
|
||||||
|
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||||
|
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInitOIDC(t *testing.T) {
|
func TestInitOIDC(t *testing.T) {
|
||||||
common.IsTest = true
|
setupMockOIDC(t)
|
||||||
t.Cleanup(func() {
|
// Create a test server that serves the discovery document
|
||||||
common.IsTest = false
|
var server *httptest.Server
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
ExpectNoError(t, json.NewEncoder(w).Encode(discoveryDocument(t, server)))
|
||||||
})
|
})
|
||||||
|
server = httptest.NewServer(mux)
|
||||||
|
t.Cleanup(server.Close)
|
||||||
|
t.Cleanup(cleanup)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
issuerURL string
|
issuerURL string
|
||||||
clientID string
|
clientID string
|
||||||
clientSecret string
|
clientSecret string
|
||||||
redirectURL string
|
redirectURL string
|
||||||
wantErr bool
|
allowedUsers []string
|
||||||
|
allowedGroups []string
|
||||||
|
wantErr bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "Success - Empty configuration",
|
name: "Fail - Empty configuration",
|
||||||
issuerURL: "",
|
issuerURL: "",
|
||||||
clientID: "",
|
clientID: "",
|
||||||
clientSecret: "",
|
clientSecret: "",
|
||||||
redirectURL: "",
|
redirectURL: "",
|
||||||
|
allowedUsers: nil,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid configuration with users",
|
||||||
|
issuerURL: server.URL,
|
||||||
|
clientID: "client_id",
|
||||||
|
clientSecret: "client_secret",
|
||||||
|
redirectURL: "https://example.com/callback",
|
||||||
|
allowedUsers: []string{"user1", "user2"},
|
||||||
wantErr: false,
|
wantErr: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid configuration with groups",
|
||||||
|
issuerURL: server.URL,
|
||||||
|
clientID: "client_id",
|
||||||
|
clientSecret: "client_secret",
|
||||||
|
redirectURL: "https://example.com/callback",
|
||||||
|
allowedGroups: []string{"group1", "group2"},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Fail - No allowed users or allowed groups",
|
||||||
|
issuerURL: "https://example.com",
|
||||||
|
clientID: "client_id",
|
||||||
|
clientSecret: "client_secret",
|
||||||
|
redirectURL: "https://example.com/callback",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Cleanup(cleanup)
|
_, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers, tt.allowedGroups)
|
||||||
err := InitOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL)
|
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCheckToken(t *testing.T) {
|
||||||
|
provider := setupProvider(t)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
allowedUsers []string
|
||||||
|
allowedGroups []string
|
||||||
|
claims jwt.Claims
|
||||||
|
wantErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Success - Valid token with allowed user",
|
||||||
|
allowedUsers: []string{"user1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid token with allowed group",
|
||||||
|
allowedGroups: []string{"group1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Server omits groups, but user is allowed",
|
||||||
|
allowedUsers: []string{"user1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Server omits preferred_username, but group is allowed",
|
||||||
|
allowedGroups: []string{"group1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Success - Valid token with allowed user and group",
|
||||||
|
allowedUsers: []string{"user1"},
|
||||||
|
allowedGroups: []string{"group1"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - User not allowed",
|
||||||
|
allowedUsers: []string{"user2", "user3"},
|
||||||
|
allowedGroups: []string{"group2", "group3"},
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrUserNotAllowed,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - Server returns incorrect issuer",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": "https://example.com",
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - Server returns incorrect audience",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": "some-other-audience",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidToken,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Error - Server returns expired token",
|
||||||
|
claims: jwt.MapClaims{
|
||||||
|
"iss": provider.ts.URL,
|
||||||
|
"aud": clientID,
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||||
|
"preferred_username": "user1",
|
||||||
|
"groups": []string{"group1"},
|
||||||
|
},
|
||||||
|
wantErr: ErrInvalidToken,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Create the Auth Provider.
|
||||||
|
auth := &OIDCProvider{
|
||||||
|
oidcVerifier: provider.verifier,
|
||||||
|
allowedUsers: tc.allowedUsers,
|
||||||
|
allowedGroups: tc.allowedGroups,
|
||||||
|
}
|
||||||
|
// Sign the claims to create a token.
|
||||||
|
signedToken := provider.SignClaims(t, tc.claims)
|
||||||
|
// Craft a test HTTP request that includes the token as a cookie.
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
req.AddCookie(&http.Cookie{
|
||||||
|
Name: auth.TokenCookieName(),
|
||||||
|
Value: signedToken,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call CheckToken and verify the result.
|
||||||
|
err := auth.CheckToken(req)
|
||||||
|
if tc.wantErr == nil {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
} else {
|
||||||
|
ExpectError(t, tc.wantErr, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
12
internal/api/v1/auth/provider.go
Normal file
12
internal/api/v1/auth/provider.go
Normal file
|
@ -0,0 +1,12 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Provider interface {
|
||||||
|
TokenCookieName() string
|
||||||
|
CheckToken(r *http.Request) error
|
||||||
|
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
|
||||||
|
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
|
||||||
|
}
|
|
@ -1,13 +1,17 @@
|
||||||
package auth
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
U "github.com/yusing/go-proxy/internal/api/v1/utils"
|
||||||
"github.com/yusing/go-proxy/internal/common"
|
"github.com/yusing/go-proxy/internal/common"
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -15,31 +19,120 @@ var (
|
||||||
ErrInvalidPassword = E.New("invalid password")
|
ErrInvalidPassword = E.New("invalid password")
|
||||||
)
|
)
|
||||||
|
|
||||||
func validatePassword(cred *Credentials) error {
|
type (
|
||||||
if cred.Username != common.APIUser {
|
UserPassAuth struct {
|
||||||
return ErrInvalidUsername.Subject(cred.Username)
|
username string
|
||||||
|
pwdHash []byte
|
||||||
|
secret []byte
|
||||||
|
tokenTTL time.Duration
|
||||||
}
|
}
|
||||||
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
|
UserPassClaims struct {
|
||||||
return ErrInvalidPassword.Subject(cred.Password)
|
Username string `json:"username"`
|
||||||
|
jwt.RegisteredClaims
|
||||||
}
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewUserPassAuth(username, password string, secret []byte, tokenTTL time.Duration) (*UserPassAuth, error) {
|
||||||
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &UserPassAuth{
|
||||||
|
username: username,
|
||||||
|
pwdHash: hash,
|
||||||
|
secret: secret,
|
||||||
|
tokenTTL: tokenTTL,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUserPassAuthFromEnv() (*UserPassAuth, error) {
|
||||||
|
return NewUserPassAuth(
|
||||||
|
common.APIUser,
|
||||||
|
common.APIPassword,
|
||||||
|
common.APIJWTSecret,
|
||||||
|
common.APIJWTTokenTTL,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *UserPassAuth) TokenCookieName() string {
|
||||||
|
return "godoxy_token"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *UserPassAuth) NewToken() (token string, err error) {
|
||||||
|
claim := &UserPassClaims{
|
||||||
|
Username: auth.username,
|
||||||
|
RegisteredClaims: jwt.RegisteredClaims{
|
||||||
|
ExpiresAt: jwt.NewNumericDate(time.Now().Add(auth.tokenTTL)),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
tok := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
|
||||||
|
token, err = tok.SignedString(auth.secret)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (auth *UserPassAuth) CheckToken(r *http.Request) error {
|
||||||
|
jwtCookie, err := r.Cookie(auth.TokenCookieName())
|
||||||
|
if err != nil {
|
||||||
|
return ErrMissingToken
|
||||||
|
}
|
||||||
|
var claims UserPassClaims
|
||||||
|
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
|
||||||
|
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
|
||||||
|
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
|
||||||
|
}
|
||||||
|
return auth.secret, nil
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch {
|
||||||
|
case !token.Valid:
|
||||||
|
return ErrInvalidToken
|
||||||
|
case claims.Username != auth.username:
|
||||||
|
return ErrUserNotAllowed.Subject(claims.Username)
|
||||||
|
case claims.ExpiresAt.Before(time.Now()):
|
||||||
|
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserPassLoginHandler handles user login.
|
func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||||
func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
|
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
|
||||||
var creds Credentials
|
}
|
||||||
|
|
||||||
|
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var creds struct {
|
||||||
|
User string `json:"username"`
|
||||||
|
Pass string `json:"password"`
|
||||||
|
}
|
||||||
err := json.NewDecoder(r.Body).Decode(&creds)
|
err := json.NewDecoder(r.Body).Decode(&creds)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusBadRequest)
|
U.HandleErr(w, r, err, http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := validatePassword(&creds); err != nil {
|
if err := auth.validatePassword(creds.User, creds.Pass); err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
U.HandleErr(w, r, err, http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
|
token, err := auth.NewToken()
|
||||||
|
if err != nil {
|
||||||
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
U.HandleErr(w, r, err, http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL)
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (auth *UserPassAuth) validatePassword(user, pass string) error {
|
||||||
|
if user != auth.username {
|
||||||
|
return ErrInvalidUsername.Subject(user)
|
||||||
|
}
|
||||||
|
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
|
||||||
|
return ErrInvalidPassword.With(err).Subject(pass)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
116
internal/api/v1/auth/userpass_test.go
Normal file
116
internal/api/v1/auth/userpass_test.go
Normal file
|
@ -0,0 +1,116 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
"golang.org/x/crypto/bcrypt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newMockUserPassAuth() *UserPassAuth {
|
||||||
|
return &UserPassAuth{
|
||||||
|
username: "username",
|
||||||
|
pwdHash: E.Must(bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)),
|
||||||
|
secret: []byte("abcdefghijklmnopqrstuvwxyz"),
|
||||||
|
tokenTTL: time.Hour,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassValidateCredentials(t *testing.T) {
|
||||||
|
auth := newMockUserPassAuth()
|
||||||
|
err := auth.validatePassword("username", "password")
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
err = auth.validatePassword("username", "wrong-password")
|
||||||
|
ExpectError(t, ErrInvalidPassword, err)
|
||||||
|
err = auth.validatePassword("wrong-username", "password")
|
||||||
|
ExpectError(t, ErrInvalidUsername, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassCheckToken(t *testing.T) {
|
||||||
|
auth := newMockUserPassAuth()
|
||||||
|
token, err := auth.NewToken()
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
tests := []struct {
|
||||||
|
token string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
token: token,
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token: "invalid-token",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
token: "",
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
req := &http.Request{Header: http.Header{}}
|
||||||
|
if tt.token != "" {
|
||||||
|
req.Header.Set("Cookie", auth.TokenCookieName()+"="+tt.token)
|
||||||
|
}
|
||||||
|
err = auth.CheckToken(req)
|
||||||
|
if tt.wantErr {
|
||||||
|
ExpectTrue(t, err != nil)
|
||||||
|
} else {
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUserPassLoginCallbackHandler(t *testing.T) {
|
||||||
|
type cred struct {
|
||||||
|
User string `json:"username"`
|
||||||
|
Pass string `json:"password"`
|
||||||
|
}
|
||||||
|
auth := newMockUserPassAuth()
|
||||||
|
tests := []struct {
|
||||||
|
creds cred
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
creds: cred{
|
||||||
|
User: "username",
|
||||||
|
Pass: "password",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
creds: cred{
|
||||||
|
User: "username",
|
||||||
|
Pass: "wrong-password",
|
||||||
|
},
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
req := &http.Request{
|
||||||
|
Host: "app.example.com",
|
||||||
|
Body: io.NopCloser(bytes.NewReader(E.Must(json.Marshal(tt.creds)))),
|
||||||
|
}
|
||||||
|
auth.LoginCallbackHandler(w, req)
|
||||||
|
if tt.wantErr {
|
||||||
|
ExpectEqual(t, w.Code, http.StatusUnauthorized)
|
||||||
|
} else {
|
||||||
|
setCookie := E.Must(http.ParseSetCookie(w.Header().Get("Set-Cookie")))
|
||||||
|
ExpectTrue(t, setCookie.Name == auth.TokenCookieName())
|
||||||
|
ExpectTrue(t, setCookie.Value != "")
|
||||||
|
ExpectEqual(t, setCookie.Domain, "example.com")
|
||||||
|
ExpectEqual(t, setCookie.Path, "/")
|
||||||
|
ExpectEqual(t, setCookie.SameSite, http.SameSiteLaxMode)
|
||||||
|
ExpectEqual(t, setCookie.HttpOnly, true)
|
||||||
|
ExpectEqual(t, w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
70
internal/api/v1/auth/utils.go
Normal file
70
internal/api/v1/auth/utils.go
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrMissingToken = E.New("missing token")
|
||||||
|
ErrInvalidToken = E.New("invalid token")
|
||||||
|
ErrUserNotAllowed = E.New("user not allowed")
|
||||||
|
)
|
||||||
|
|
||||||
|
// cookieFQDN returns the fully qualified domain name of the request host
|
||||||
|
// with subdomain stripped.
|
||||||
|
//
|
||||||
|
// If the request host does not have a subdomain,
|
||||||
|
// an empty string is returned
|
||||||
|
//
|
||||||
|
// "abc.example.com" -> "example.com"
|
||||||
|
// "example.com" -> ""
|
||||||
|
func cookieFQDN(r *http.Request) string {
|
||||||
|
host, _, err := net.SplitHostPort(r.Host)
|
||||||
|
if err != nil {
|
||||||
|
host = r.Host
|
||||||
|
}
|
||||||
|
parts := strutils.SplitRune(host, '.')
|
||||||
|
if len(parts) < 2 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts[0] = ""
|
||||||
|
return strutils.JoinRune(parts, '.')
|
||||||
|
}
|
||||||
|
|
||||||
|
func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: value,
|
||||||
|
MaxAge: int(ttl.Seconds()),
|
||||||
|
Domain: cookieFQDN(r),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: name,
|
||||||
|
Value: "",
|
||||||
|
MaxAge: -1,
|
||||||
|
Domain: cookieFQDN(r),
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
Path: "/",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogoutCallbackHandler(auth Provider) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
clearTokenCookie(w, r, auth.TokenCookieName())
|
||||||
|
auth.RedirectLoginPage(w, r)
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
|
||||||
return logging.WithLevel(level).
|
return logging.WithLevel(level).
|
||||||
Str("module", "api").
|
Str("module", "api").
|
||||||
Str("remote", r.RemoteAddr).
|
Str("remote", r.RemoteAddr).
|
||||||
|
Str("host", r.Host).
|
||||||
Str("uri", r.Method+" "+r.RequestURI)
|
Str("uri", r.Method+" "+r.RequestURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,18 +1,11 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha512"
|
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HashPassword(pwd string) []byte {
|
|
||||||
h := sha512.New()
|
|
||||||
h.Write([]byte(pwd))
|
|
||||||
return h.Sum(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeJWTKey(key string) []byte {
|
func decodeJWTKey(key string) []byte {
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -9,6 +9,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
|
"github.com/yusing/go-proxy/internal/utils/strutils"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -43,17 +44,19 @@ var (
|
||||||
MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http")
|
MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http")
|
||||||
PrometheusEnabled = MetricsHTTPURL != ""
|
PrometheusEnabled = MetricsHTTPURL != ""
|
||||||
|
|
||||||
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
|
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
|
||||||
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
|
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
|
||||||
APIUser = GetEnvString("API_USER", "admin")
|
APIUser = GetEnvString("API_USER", "admin")
|
||||||
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password"))
|
APIPassword = GetEnvString("API_PASSWORD", "password")
|
||||||
|
|
||||||
// OIDC Configuration.
|
// OIDC Configuration.
|
||||||
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
|
||||||
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
|
||||||
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
|
||||||
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
|
||||||
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
|
||||||
|
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
|
||||||
|
OIDCAllowedGroups = GetCommaSepEnv("OIDC_ALLOWED_GROUPS", "")
|
||||||
)
|
)
|
||||||
|
|
||||||
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
|
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
|
||||||
|
@ -105,3 +108,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
|
||||||
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
|
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
|
||||||
return GetEnv(key, defaultValue, time.ParseDuration)
|
return GetEnv(key, defaultValue, time.ParseDuration)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func GetCommaSepEnv(key string, defaultValue string) []string {
|
||||||
|
return strutils.CommaSeperatedList(GetEnvString(key, defaultValue))
|
||||||
|
}
|
||||||
|
|
|
@ -1,13 +0,0 @@
|
||||||
package common
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/rand"
|
|
||||||
"encoding/base64"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GenerateRandomString generates a random string of specified length.
|
|
||||||
func GenerateRandomString(length int) string {
|
|
||||||
b := make([]byte, length)
|
|
||||||
rand.Read(b)
|
|
||||||
return base64.URLEncoding.EncodeToString(b)[:length]
|
|
||||||
}
|
|
|
@ -89,7 +89,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
// Then scraper / scanners will know the subdomain is invalid.
|
// Then scraper / scanners will know the subdomain is invalid.
|
||||||
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
|
||||||
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
if served := middleware.ServeStaticErrorPageFile(w, r); !served {
|
||||||
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request")
|
logger.Err(err).
|
||||||
|
Str("method", r.Method).
|
||||||
|
Str("url", r.URL.String()).
|
||||||
|
Str("remote", r.RemoteAddr).
|
||||||
|
Msg("request")
|
||||||
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
|
||||||
if ok {
|
if ok {
|
||||||
w.WriteHeader(http.StatusNotFound)
|
w.WriteHeader(http.StatusNotFound)
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net/http"
|
"net/http"
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
E "github.com/yusing/go-proxy/internal/error"
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
@ -26,28 +27,50 @@ type (
|
||||||
name string
|
name string
|
||||||
construct ImplNewFunc
|
construct ImplNewFunc
|
||||||
impl any
|
impl any
|
||||||
|
// priority is only applied for ReverseProxy.
|
||||||
|
//
|
||||||
|
// Middleware compose follows the order of the slice
|
||||||
|
//
|
||||||
|
// Default is 10, 0 is the highest
|
||||||
|
priority int
|
||||||
}
|
}
|
||||||
|
ByPriority []*Middleware
|
||||||
|
|
||||||
RequestModifier interface {
|
RequestModifier interface {
|
||||||
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
before(w http.ResponseWriter, r *http.Request) (proceed bool)
|
||||||
}
|
}
|
||||||
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
ResponseModifier interface{ modifyResponse(r *http.Response) error }
|
||||||
MiddlewareWithSetup interface{ setup() }
|
MiddlewareWithSetup interface{ setup() }
|
||||||
MiddlewareFinalizer interface{ finalize() }
|
MiddlewareFinalizer interface{ finalize() }
|
||||||
|
MiddlewareFinalizerWithError interface {
|
||||||
|
finalize() error
|
||||||
|
}
|
||||||
MiddlewareWithTracer interface {
|
MiddlewareWithTracer interface {
|
||||||
enableTrace()
|
enableTrace()
|
||||||
getTracer() *Tracer
|
getTracer() *Tracer
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const DefaultPriority = 10
|
||||||
|
|
||||||
|
func (m ByPriority) Len() int { return len(m) }
|
||||||
|
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
|
||||||
|
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
|
||||||
|
|
||||||
func NewMiddleware[ImplType any]() *Middleware {
|
func NewMiddleware[ImplType any]() *Middleware {
|
||||||
// type check
|
// type check
|
||||||
switch any(new(ImplType)).(type) {
|
t := any(new(ImplType))
|
||||||
|
switch t.(type) {
|
||||||
case RequestModifier:
|
case RequestModifier:
|
||||||
case ResponseModifier:
|
case ResponseModifier:
|
||||||
default:
|
default:
|
||||||
panic("must implement RequestModifier or ResponseModifier")
|
panic("must implement RequestModifier or ResponseModifier")
|
||||||
}
|
}
|
||||||
|
_, hasFinializer := t.(MiddlewareFinalizer)
|
||||||
|
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
|
||||||
|
if hasFinializer && hasFinializerWithError {
|
||||||
|
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
|
||||||
|
}
|
||||||
return &Middleware{
|
return &Middleware{
|
||||||
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
|
||||||
construct: func() any { return new(ImplType) },
|
construct: func() any { return new(ImplType) },
|
||||||
|
@ -84,13 +107,29 @@ func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
|
||||||
if len(optsRaw) == 0 {
|
if len(optsRaw) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
priority, ok := optsRaw["priority"].(int)
|
||||||
|
if ok {
|
||||||
|
m.priority = priority
|
||||||
|
// remove priority for deserialization, restore later
|
||||||
|
delete(optsRaw, "priority")
|
||||||
|
defer func() {
|
||||||
|
optsRaw["priority"] = priority
|
||||||
|
}()
|
||||||
|
} else {
|
||||||
|
m.priority = DefaultPriority
|
||||||
|
}
|
||||||
return utils.Deserialize(optsRaw, m.impl)
|
return utils.Deserialize(optsRaw, m.impl)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) finalize() {
|
func (m *Middleware) finalize() error {
|
||||||
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
|
||||||
finalizer.finalize()
|
finalizer.finalize()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
|
||||||
|
return finalizer.finalize()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
|
@ -105,7 +144,9 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
|
||||||
if err := mid.apply(optsRaw); err != nil {
|
if err := mid.apply(optsRaw); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mid.finalize()
|
if err := mid.finalize(); err != nil {
|
||||||
|
return nil, E.From(err)
|
||||||
|
}
|
||||||
return mid, nil
|
return mid, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,8 +160,9 @@ func (m *Middleware) String() string {
|
||||||
|
|
||||||
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
func (m *Middleware) MarshalJSON() ([]byte, error) {
|
||||||
return json.MarshalIndent(map[string]any{
|
return json.MarshalIndent(map[string]any{
|
||||||
"name": m.name,
|
"name": m.name,
|
||||||
"options": m.impl,
|
"options": m.impl,
|
||||||
|
"priority": m.priority,
|
||||||
}, "", " ")
|
}, "", " ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,6 +235,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
|
||||||
}
|
}
|
||||||
|
|
||||||
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
|
||||||
|
sort.Sort(ByPriority(middlewares))
|
||||||
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
|
||||||
|
|
||||||
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
mid := NewMiddlewareChain(rp.TargetName, middlewares)
|
||||||
|
|
37
internal/net/http/middleware/middleware_test.go
Normal file
37
internal/net/http/middleware/middleware_test.go
Normal file
|
@ -0,0 +1,37 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
. "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testPriority struct {
|
||||||
|
Value int `json:"value"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var test = NewMiddleware[testPriority]()
|
||||||
|
|
||||||
|
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMiddlewarePriority(t *testing.T) {
|
||||||
|
priorities := []int{4, 7, 9, 0}
|
||||||
|
chain := make([]*Middleware, len(priorities))
|
||||||
|
for i, p := range priorities {
|
||||||
|
mid, err := test.New(OptionsRaw{
|
||||||
|
"priority": p,
|
||||||
|
"value": i,
|
||||||
|
})
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
chain[i] = mid
|
||||||
|
}
|
||||||
|
res, err := newMiddlewaresTest(chain, nil)
|
||||||
|
ExpectNoError(t, err)
|
||||||
|
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
|
||||||
|
}
|
|
@ -14,6 +14,8 @@ import (
|
||||||
var allMiddlewares = map[string]*Middleware{
|
var allMiddlewares = map[string]*Middleware{
|
||||||
"redirecthttp": RedirectHTTP,
|
"redirecthttp": RedirectHTTP,
|
||||||
|
|
||||||
|
"oidc": OIDC,
|
||||||
|
|
||||||
"request": ModifyRequest,
|
"request": ModifyRequest,
|
||||||
"modifyrequest": ModifyRequest,
|
"modifyrequest": ModifyRequest,
|
||||||
"response": ModifyResponse,
|
"response": ModifyResponse,
|
||||||
|
|
59
internal/net/http/middleware/oidc.go
Normal file
59
internal/net/http/middleware/oidc.go
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/yusing/go-proxy/internal/api/v1/auth"
|
||||||
|
E "github.com/yusing/go-proxy/internal/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
type oidcMiddleware struct {
|
||||||
|
AllowedUsers []string `json:"allowed_users"`
|
||||||
|
AllowedGroups []string `json:"allowed_groups"`
|
||||||
|
|
||||||
|
auth auth.Provider
|
||||||
|
authMux *http.ServeMux
|
||||||
|
logoutHandler http.HandlerFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
var OIDC = NewMiddleware[oidcMiddleware]()
|
||||||
|
|
||||||
|
func (amw *oidcMiddleware) finalize() error {
|
||||||
|
if !auth.IsOIDCEnabled() {
|
||||||
|
return E.New("OIDC not enabled but ODIC middleware is used")
|
||||||
|
}
|
||||||
|
authProvider, err := auth.NewOIDCProviderFromEnv()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
authProvider.SetIsMiddleware(true)
|
||||||
|
if len(amw.AllowedUsers) > 0 {
|
||||||
|
authProvider.SetAllowedUsers(amw.AllowedUsers)
|
||||||
|
}
|
||||||
|
if len(amw.AllowedGroups) > 0 {
|
||||||
|
authProvider.SetAllowedGroups(amw.AllowedGroups)
|
||||||
|
}
|
||||||
|
|
||||||
|
amw.authMux = http.NewServeMux()
|
||||||
|
amw.authMux.HandleFunc(auth.OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
|
||||||
|
amw.authMux.HandleFunc(auth.OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||||
|
})
|
||||||
|
amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
|
||||||
|
amw.logoutHandler = auth.LogoutCallbackHandler(authProvider)
|
||||||
|
amw.auth = authProvider
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
|
||||||
|
if err := amw.auth.CheckToken(r); err != nil {
|
||||||
|
amw.authMux.ServeHTTP(w, r)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if r.URL.Path == auth.OIDCLogoutPath {
|
||||||
|
amw.logoutHandler(w, r)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
|
@ -127,6 +127,20 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
||||||
}
|
}
|
||||||
args.setDefaults()
|
args.setDefaults()
|
||||||
|
|
||||||
|
mid, setOptErr := middleware.New(args.middlewareOpt)
|
||||||
|
if setOptErr != nil {
|
||||||
|
return nil, setOptErr
|
||||||
|
}
|
||||||
|
|
||||||
|
return newMiddlewaresTest([]*Middleware{mid}, args)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) {
|
||||||
|
if args == nil {
|
||||||
|
args = new(testArgs)
|
||||||
|
}
|
||||||
|
args.setDefaults()
|
||||||
|
|
||||||
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
|
||||||
for k, v := range args.headers {
|
for k, v := range args.headers {
|
||||||
req.Header[k] = v
|
req.Header[k] = v
|
||||||
|
@ -139,14 +153,8 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
|
||||||
rr.parent = http.DefaultTransport
|
rr.parent = http.DefaultTransport
|
||||||
}
|
}
|
||||||
|
|
||||||
rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr)
|
rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
|
||||||
|
patchReverseProxy(rp, middlewares)
|
||||||
mid, setOptErr := middleware.New(args.middlewareOpt)
|
|
||||||
if setOptErr != nil {
|
|
||||||
return nil, setOptErr
|
|
||||||
}
|
|
||||||
|
|
||||||
patchReverseProxy(rp, []*Middleware{mid})
|
|
||||||
rp.ServeHTTP(w, req)
|
rp.ServeHTTP(w, req)
|
||||||
|
|
||||||
resp := w.Result()
|
resp := w.Result()
|
||||||
|
|
20
internal/utils/slices.go
Normal file
20
internal/utils/slices.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
// Intersect returns a new slice containing the elements that are present in both input slices.
|
||||||
|
// This provides a more efficient solution than using two nested loops.
|
||||||
|
func Intersect[T comparable, Slice ~[]T](slice1 Slice, slice2 Slice) Slice {
|
||||||
|
var result Slice
|
||||||
|
seen := map[T]struct{}{}
|
||||||
|
|
||||||
|
for i := range slice1 {
|
||||||
|
seen[slice1[i]] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range slice2 {
|
||||||
|
if _, ok := seen[slice2[i]]; ok {
|
||||||
|
result = append(result, slice2[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
96
internal/utils/slices_test.go
Normal file
96
internal/utils/slices_test.go
Normal file
|
@ -0,0 +1,96 @@
|
||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
utils "github.com/yusing/go-proxy/internal/utils/testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestIntersect(t *testing.T) {
|
||||||
|
t.Run("strings", func(t *testing.T) {
|
||||||
|
t.Run("no intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []string{"a", "b", "c"}
|
||||||
|
slice2 = []string{"d", "e", "f"}
|
||||||
|
want []string
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
t.Run("intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []string{"a", "b", "c"}
|
||||||
|
slice2 = []string{"b", "c", "d"}
|
||||||
|
want = []string{"b", "c"}
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("ints", func(t *testing.T) {
|
||||||
|
t.Run("no intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []int{1, 2, 3}
|
||||||
|
slice2 = []int{4, 5, 6}
|
||||||
|
want []int
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
t.Run("intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []int{1, 2, 3}
|
||||||
|
slice2 = []int{2, 3, 4}
|
||||||
|
want = []int{2, 3}
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.Sort(result)
|
||||||
|
slices.Sort(want)
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
t.Run("complex", func(t *testing.T) {
|
||||||
|
type T struct {
|
||||||
|
A string
|
||||||
|
B int
|
||||||
|
}
|
||||||
|
t.Run("no intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
|
||||||
|
slice2 = []T{{"d", 4}, {"e", 5}, {"f", 6}}
|
||||||
|
want []T
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.SortFunc(result, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
slices.SortFunc(want, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
t.Run("intersection", func(t *testing.T) {
|
||||||
|
var (
|
||||||
|
slice1 = []T{{"a", 1}, {"b", 2}, {"c", 3}}
|
||||||
|
slice2 = []T{{"b", 2}, {"c", 3}, {"d", 4}}
|
||||||
|
want = []T{{"b", 2}, {"c", 3}}
|
||||||
|
)
|
||||||
|
result := Intersect(slice1, slice2)
|
||||||
|
slices.SortFunc(result, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
slices.SortFunc(want, func(i T, j T) int {
|
||||||
|
return strings.Compare(i.A, j.A)
|
||||||
|
})
|
||||||
|
utils.ExpectDeepEqual(t, result, want)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
|
@ -10,6 +10,9 @@ import (
|
||||||
// CommaSeperatedList returns a list of strings split by commas,
|
// CommaSeperatedList returns a list of strings split by commas,
|
||||||
// then trim spaces from each element.
|
// then trim spaces from each element.
|
||||||
func CommaSeperatedList(s string) []string {
|
func CommaSeperatedList(s string) []string {
|
||||||
|
if s == "" {
|
||||||
|
return []string{}
|
||||||
|
}
|
||||||
res := SplitComma(s)
|
res := SplitComma(s)
|
||||||
for i, part := range res {
|
for i, part := range res {
|
||||||
res[i] = strings.TrimSpace(part)
|
res[i] = strings.TrimSpace(part)
|
||||||
|
|
|
@ -73,6 +73,26 @@ GoDoxy v0.8.2 expected changes
|
||||||
* Connection #0 to host localhost left intact
|
* Connection #0 to host localhost left intact
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- **Thanks [polds](https://github.com/polds)**
|
||||||
|
Support WebUI authentication via OIDC by setting these environment variables:
|
||||||
|
- `GODOXY_OIDC_ISSUER_URL` e.g.:
|
||||||
|
- Pocket ID: `https://pocker-id.yourdomain.com`
|
||||||
|
- Authentik: `https://authentik.yourdomain.com/application/o/<application_slug>/` **The ending slash is required**
|
||||||
|
- `GODOXY_OIDC_CLIENT_ID`
|
||||||
|
- `GODOXY_OIDC_CLIENT_SECRET`
|
||||||
|
- `GODOXY_OIDC_REDIRECT_URL`
|
||||||
|
- `GODOXY_OIDC_SCOPES` _(optional)_
|
||||||
|
- `GODOXY_OIDC_ALLOWED_USERS`
|
||||||
|
|
||||||
|
- Use OpenID Connect to authenticate GoDoxy's WebUI and all your services (SSO)
|
||||||
|
```yaml
|
||||||
|
# default
|
||||||
|
proxy.app.middlewares.oidc:
|
||||||
|
|
||||||
|
# override allowed users
|
||||||
|
proxy.app.middlewares.oidc.allowed_users: user1, user2
|
||||||
|
```
|
||||||
|
|
||||||
- Caddyfile like rules
|
- Caddyfile like rules
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|
Loading…
Add table
Reference in a new issue