implement OIDC middleware

This commit is contained in:
yusing 2025-01-14 03:53:07 +08:00
parent 2af2346e35
commit bb0ee5d7a9
15 changed files with 321 additions and 110 deletions

View file

@ -2,6 +2,7 @@
TZ=ETC/UTC TZ=ETC/UTC
# generate secret with `openssl rand -base64 32` # generate secret with `openssl rand -base64 32`
# used for both user password authentication and OIDC
GODOXY_API_JWT_SECRET= GODOXY_API_JWT_SECRET=
# the JWT token time-to-live # the JWT token time-to-live
@ -11,6 +12,7 @@ GODOXY_API_JWT_TOKEN_TTL=1h
# Important: If using OIDC authentication, the API_USER must match the username # Important: If using OIDC authentication, the API_USER must match the username
# provided by the OIDC provider. # provided by the OIDC provider.
GODOXY_API_USER=admin GODOXY_API_USER=admin
# Password is not required for OIDC authentication
GODOXY_API_PASSWORD=password GODOXY_API_PASSWORD=password
# OIDC Configuration (optional) # OIDC Configuration (optional)

View file

@ -39,7 +39,7 @@ profile:
run: build run: build
sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy
bin/godoxy [ -f .env ] && godotenv -f .env bin/godoxy || bin/godoxy
mtrace: mtrace:
bin/godoxy debug-ls-mtrace > mtrace.json bin/godoxy debug-ls-mtrace > mtrace.json

View file

@ -9,7 +9,6 @@ import (
"time" "time"
"github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
@ -112,15 +111,6 @@ func main() {
cfg.Start() cfg.Start()
config.WatchChanges() config.WatchChanges()
if !auth.IsEnabled() {
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
} else {
// Initialize authentication providers
if err := auth.Initialize(); err != nil {
logging.Fatal().Err(err).Msg("Failed to initialize authentication providers")
}
}
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)
signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGINT)
signal.Notify(sig, syscall.SIGTERM) signal.Notify(sig, syscall.SIGTERM)

View file

@ -23,8 +23,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler) mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler) mux.HandleFunc("GET", "/v1/auth/redirect", auth.APIAuthRedirectHandler)
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) mux.HandleFunc("GET", "/v1/auth/callback", auth.APIOIDCCallbackHandler)
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))

View file

@ -2,6 +2,7 @@ package auth
import ( import (
"fmt" "fmt"
"net"
"net/http" "net/http"
"time" "time"
@ -9,6 +10,7 @@ import (
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/strutils" "github.com/yusing/go-proxy/internal/utils/strutils"
) )
@ -23,29 +25,59 @@ type (
} }
) )
// Initialize sets up authentication providers. // init sets up authentication providers.
func Initialize() error { func init() {
if !IsEnabled() {
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
return
}
// Initialize OIDC if configured. // Initialize OIDC if configured.
if common.OIDCIssuerURL != "" { if common.OIDCIssuerURL != "" {
return InitOIDC( if err := initOIDC(
common.OIDCIssuerURL, common.OIDCIssuerURL,
common.OIDCClientID, common.OIDCClientID,
common.OIDCClientSecret, common.OIDCClientSecret,
common.OIDCRedirectURL, common.OIDCRedirectURL,
) ); err != nil {
logging.Fatal().Err(err).Msg("failed to initialize OIDC provider")
}
} }
return nil
} }
func IsEnabled() bool { func IsEnabled() bool {
return common.APIJWTSecret != nil || common.OIDCIssuerURL != "" return common.APIJWTSecret != nil || IsOIDCEnabled()
} }
// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration. func IsOIDCEnabled() bool {
func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) { return common.OIDCIssuerURL != ""
}
// cookieFQDN returns the fully qualified domain name of the request host
// with subdomain stripped.
//
// If the request host does not have a subdomain,
// an empty string is returned
//
// "abc.example.com" -> "example.com"
// "example.com" -> ""
func cookieFQDN(r *http.Request) string {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
parts := strutils.SplitRune(host, '.')
if len(parts) < 2 {
return ""
}
parts[0] = ""
return strutils.JoinRune(parts, '.')
}
// APIAuthRedirectHandler handles API redirect to login page or OIDC login base on configuration.
func APIAuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
switch { switch {
case oauthConfig != nil: case apiOAuth != nil:
RedirectOIDC(w, r) apiOAuth.RedirectOIDC(w, r)
return return
case common.APIJWTSecret != nil: case common.APIJWTSecret != nil:
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
@ -55,7 +87,7 @@ func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
} }
} }
func setAuthenticatedCookie(w http.ResponseWriter, username string) error { func setAuthenticatedCookie(w http.ResponseWriter, r *http.Request, username string) error {
expiresAt := time.Now().Add(common.APIJWTTokenTTL) expiresAt := time.Now().Add(common.APIJWTTokenTTL)
claim := &Claims{ claim := &Claims{
Username: username, Username: username,
@ -72,9 +104,10 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
Name: CookieToken, Name: CookieToken,
Value: tokenStr, Value: tokenStr,
Expires: expiresAt, Expires: expiresAt,
Domain: cookieFQDN(r),
HttpOnly: true, HttpOnly: true,
Secure: true, Secure: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteLaxMode,
Path: "/", Path: "/",
}) })
return nil return nil
@ -84,20 +117,22 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
func LogoutHandler(w http.ResponseWriter, r *http.Request) { func LogoutHandler(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: CookieToken, Name: CookieToken,
Value: "", MaxAge: -1,
Expires: time.Unix(0, 0), Domain: cookieFQDN(r),
HttpOnly: true, HttpOnly: true,
Secure: true, Secure: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteLaxMode,
Path: "/", Path: "/",
}) })
AuthRedirectHandler(w, r) APIAuthRedirectHandler(w, r)
} }
func RequireAuth(next http.HandlerFunc) http.HandlerFunc { func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if IsEnabled() { if IsEnabled() {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if checkToken(w, r) { if err := CheckToken(w, r); err != nil {
U.RespondError(w, err, http.StatusUnauthorized)
} else {
next(w, r) next(w, r)
} }
} }
@ -105,11 +140,10 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
return next return next
} }
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { func CheckToken(w http.ResponseWriter, r *http.Request) error {
tokenCookie, err := r.Cookie(CookieToken) tokenCookie, err := r.Cookie(CookieToken)
if err != nil { if err != nil {
U.RespondError(w, E.New("missing token"), http.StatusUnauthorized) return E.New("missing token")
return false
} }
var claims Claims var claims Claims
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
@ -118,22 +152,17 @@ func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
} }
return common.APIJWTSecret, nil return common.APIJWTSecret, nil
}) })
switch {
case err != nil:
break
case !token.Valid:
err = E.New("invalid token")
case claims.Username != common.APIUser:
err = E.New("username mismatch").Subject(claims.Username)
case claims.ExpiresAt.Before(time.Now()):
err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
}
if err != nil { if err != nil {
U.RespondError(w, err, http.StatusForbidden) return err
return false }
switch {
case !token.Valid:
return E.New("invalid token")
case claims.Username != common.APIUser:
return E.New("username mismatch").Subject(claims.Username)
case claims.ExpiresAt.Before(time.Now()):
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
} }
return true return nil
} }

View file

@ -13,42 +13,66 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
var ( type OIDCProvider struct {
oauthConfig *oauth2.Config oauthConfig *oauth2.Config
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier oidcVerifier *oidc.IDTokenVerifier
overrideHost bool
}
var (
apiOAuth *OIDCProvider
APIOIDCCallbackHandler http.HandlerFunc
) )
// InitOIDC initializes the OIDC provider. // initOIDC initializes the OIDC provider.
func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error { func initOIDC(issuerURL, clientID, clientSecret, redirectURL string) (err error) {
if issuerURL == "" { if issuerURL == "" {
return nil // OIDC not configured return nil // OIDC not configured
} }
apiOAuth, err = NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL)
APIOIDCCallbackHandler = apiOAuth.OIDCCallbackHandler
return
}
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OIDCProvider, error) {
provider, err := oidc.NewProvider(context.Background(), issuerURL) provider, err := oidc.NewProvider(context.Background(), issuerURL)
if err != nil { if err != nil {
return fmt.Errorf("failed to initialize OIDC provider: %w", err) return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
} }
oidcProvider = provider return &OIDCProvider{
oidcVerifier = provider.Verifier(&oidc.Config{ oauthConfig: &oauth2.Config{
ClientID: clientID, ClientID: clientID,
}) ClientSecret: clientSecret,
RedirectURL: redirectURL,
Endpoint: provider.Endpoint(),
Scopes: strutils.CommaSeperatedList(common.OIDCScopes),
},
oidcProvider: provider,
oidcVerifier: provider.Verifier(&oidc.Config{
ClientID: clientID,
}),
}, nil
}
oauthConfig = &oauth2.Config{ func NewOIDCProviderFromEnv(redirectURL string) (*OIDCProvider, error) {
ClientID: clientID, return NewOIDCProvider(
ClientSecret: clientSecret, common.OIDCIssuerURL,
RedirectURL: redirectURL, common.OIDCClientID,
Endpoint: provider.Endpoint(), common.OIDCClientSecret,
Scopes: strutils.CommaSeperatedList(common.OIDCScopes), redirectURL,
} )
}
return nil func (provider *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
provider.overrideHost = enabled
} }
// RedirectOIDC initiates the OIDC login flow. // RedirectOIDC initiates the OIDC login flow.
func RedirectOIDC(w http.ResponseWriter, r *http.Request) { func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Request) {
if oauthConfig == nil { if provider == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return return
} }
@ -59,18 +83,29 @@ func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
Value: state, Value: state,
MaxAge: 300, MaxAge: 300,
HttpOnly: true, HttpOnly: true,
SameSite: http.SameSiteNoneMode, SameSite: http.SameSiteLaxMode,
Secure: true, Secure: true,
Path: "/", Path: "/",
}) })
url := oauthConfig.AuthCodeURL(state) redirURL := provider.oauthConfig.AuthCodeURL(state)
http.Redirect(w, r, url, http.StatusTemporaryRedirect) if provider.overrideHost {
u, err := r.URL.Parse(redirURL)
if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
q := u.Query()
q.Set("redirect_uri", "https://"+r.Host+q.Get("redirect_uri"))
u.RawQuery = q.Encode()
redirURL = u.String()
}
http.Redirect(w, r, redirURL, http.StatusTemporaryRedirect)
} }
// OIDCCallbackHandler handles the OIDC callback. // OIDCCallbackHandler handles the OIDC callback.
func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
if oauthConfig == nil { if provider == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return return
} }
@ -81,7 +116,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if oidcProvider == nil { if provider.oidcProvider == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return return
} }
@ -98,7 +133,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
oauth2Token, err := oauthConfig.Exchange(r.Context(), code) oauth2Token, err := provider.oauthConfig.Exchange(r.Context(), code)
if err != nil { if err != nil {
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError) U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
return return
@ -110,7 +145,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
idToken, err := oidcVerifier.Verify(r.Context(), rawIDToken) idToken, err := provider.oidcVerifier.Verify(r.Context(), rawIDToken)
if err != nil { if err != nil {
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError) U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
return return
@ -125,7 +160,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
if err := setAuthenticatedCookie(w, claims.Username); err != nil { if err := setAuthenticatedCookie(w, r, claims.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError) U.HandleErr(w, r, err, http.StatusInternalServerError)
return return
} }
@ -148,7 +183,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
} }
// Create test JWT token // Create test JWT token
if err := setAuthenticatedCookie(w, "test-user"); err != nil { if err := setAuthenticatedCookie(w, r, "test-user"); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError) U.HandleErr(w, r, err, http.StatusInternalServerError)
return return
} }

View file

@ -14,22 +14,22 @@ import (
func setupMockOIDC(t *testing.T) { func setupMockOIDC(t *testing.T) {
t.Helper() t.Helper()
oauthConfig = &oauth2.Config{ apiOAuth = &OIDCProvider{
ClientID: "test-client", oauthConfig: &oauth2.Config{
ClientSecret: "test-secret", ClientID: "test-client",
RedirectURL: "http://localhost/callback", ClientSecret: "test-secret",
Endpoint: oauth2.Endpoint{ RedirectURL: "http://localhost/callback",
AuthURL: "http://mock-provider/auth", Endpoint: oauth2.Endpoint{
TokenURL: "http://mock-provider/token", AuthURL: "http://mock-provider/auth",
TokenURL: "http://mock-provider/token",
},
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}, },
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
} }
} }
func cleanup() { func cleanup() {
oauthConfig = nil apiOAuth = nil
oidcProvider = nil
oidcVerifier = nil
} }
func TestOIDCLoginHandler(t *testing.T) { func TestOIDCLoginHandler(t *testing.T) {
@ -65,13 +65,13 @@ func TestOIDCLoginHandler(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if !tt.configureOAuth { if !tt.configureOAuth {
oauthConfig = nil apiOAuth = nil
} }
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
RedirectOIDC(w, req) apiOAuth.RedirectOIDC(w, req)
if got := w.Code; got != tt.wantStatus { if got := w.Code; got != tt.wantStatus {
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
@ -140,7 +140,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
} }
if !tt.configureOAuth { if !tt.configureOAuth {
oauthConfig = nil apiOAuth = nil
} }
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil) req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
@ -152,7 +152,7 @@ func TestOIDCCallbackHandler(t *testing.T) {
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
OIDCCallbackHandler(w, req) apiOAuth.OIDCCallbackHandler(w, req)
if got := w.Code; got != tt.wantStatus { if got := w.Code; got != tt.wantStatus {
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus) t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
@ -194,7 +194,7 @@ func TestInitOIDC(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Cleanup(cleanup) t.Cleanup(cleanup)
err := InitOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL) err := initOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
} }

View file

@ -37,7 +37,7 @@ func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
U.HandleErr(w, r, err, http.StatusUnauthorized) U.HandleErr(w, r, err, http.StatusUnauthorized)
return return
} }
if err := setAuthenticatedCookie(w, creds.Username); err != nil { if err := setAuthenticatedCookie(w, r, creds.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError) U.HandleErr(w, r, err, http.StatusInternalServerError)
return return
} }

View file

@ -89,7 +89,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Then scraper / scanners will know the subdomain is invalid. // Then scraper / scanners will know the subdomain is invalid.
// With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid.
if served := middleware.ServeStaticErrorPageFile(w, r); !served { if served := middleware.ServeStaticErrorPageFile(w, r); !served {
logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") logger.Err(err).
Str("method", r.Method).
Str("url", r.URL.String()).
Str("remote", r.RemoteAddr).
Msg("request")
errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound)
if ok { if ok {
w.WriteHeader(http.StatusNotFound) w.WriteHeader(http.StatusNotFound)

View file

@ -4,6 +4,7 @@ import (
"encoding/json" "encoding/json"
"net/http" "net/http"
"reflect" "reflect"
"sort"
"strings" "strings"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -26,28 +27,50 @@ type (
name string name string
construct ImplNewFunc construct ImplNewFunc
impl any impl any
// priority is only applied for ReverseProxy.
//
// Middleware compose follows the order of the slice
//
// Default is 10, 0 is the highest
priority int
} }
ByPriority []*Middleware
RequestModifier interface { RequestModifier interface {
before(w http.ResponseWriter, r *http.Request) (proceed bool) before(w http.ResponseWriter, r *http.Request) (proceed bool)
} }
ResponseModifier interface{ modifyResponse(r *http.Response) error } ResponseModifier interface{ modifyResponse(r *http.Response) error }
MiddlewareWithSetup interface{ setup() } MiddlewareWithSetup interface{ setup() }
MiddlewareFinalizer interface{ finalize() } MiddlewareFinalizer interface{ finalize() }
MiddlewareFinalizerWithError interface {
finalize() error
}
MiddlewareWithTracer interface { MiddlewareWithTracer interface {
enableTrace() enableTrace()
getTracer() *Tracer getTracer() *Tracer
} }
) )
const DefaultPriority = 10
func (m ByPriority) Len() int { return len(m) }
func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority }
func NewMiddleware[ImplType any]() *Middleware { func NewMiddleware[ImplType any]() *Middleware {
// type check // type check
switch any(new(ImplType)).(type) { t := any(new(ImplType))
switch t.(type) {
case RequestModifier: case RequestModifier:
case ResponseModifier: case ResponseModifier:
default: default:
panic("must implement RequestModifier or ResponseModifier") panic("must implement RequestModifier or ResponseModifier")
} }
_, hasFinializer := t.(MiddlewareFinalizer)
_, hasFinializerWithError := t.(MiddlewareFinalizerWithError)
if hasFinializer && hasFinializerWithError {
panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive")
}
return &Middleware{ return &Middleware{
name: strings.ToLower(reflect.TypeFor[ImplType]().Name()), name: strings.ToLower(reflect.TypeFor[ImplType]().Name()),
construct: func() any { return new(ImplType) }, construct: func() any { return new(ImplType) },
@ -84,13 +107,29 @@ func (m *Middleware) apply(optsRaw OptionsRaw) E.Error {
if len(optsRaw) == 0 { if len(optsRaw) == 0 {
return nil return nil
} }
priority, ok := optsRaw["priority"].(int)
if ok {
m.priority = priority
// remove priority for deserialization, restore later
delete(optsRaw, "priority")
defer func() {
optsRaw["priority"] = priority
}()
} else {
m.priority = DefaultPriority
}
return utils.Deserialize(optsRaw, m.impl) return utils.Deserialize(optsRaw, m.impl)
} }
func (m *Middleware) finalize() { func (m *Middleware) finalize() error {
if finalizer, ok := m.impl.(MiddlewareFinalizer); ok { if finalizer, ok := m.impl.(MiddlewareFinalizer); ok {
finalizer.finalize() finalizer.finalize()
return nil
} }
if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok {
return finalizer.finalize()
}
return nil
} }
func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
@ -105,7 +144,9 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) {
if err := mid.apply(optsRaw); err != nil { if err := mid.apply(optsRaw); err != nil {
return nil, err return nil, err
} }
mid.finalize() if err := mid.finalize(); err != nil {
return nil, E.From(err)
}
return mid, nil return mid, nil
} }
@ -119,8 +160,9 @@ func (m *Middleware) String() string {
func (m *Middleware) MarshalJSON() ([]byte, error) { func (m *Middleware) MarshalJSON() ([]byte, error) {
return json.MarshalIndent(map[string]any{ return json.MarshalIndent(map[string]any{
"name": m.name, "name": m.name,
"options": m.impl, "options": m.impl,
"priority": m.priority,
}, "", " ") }, "", " ")
} }
@ -193,6 +235,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) (
} }
func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) {
sort.Sort(ByPriority(middlewares))
middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...) middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...)
mid := NewMiddlewareChain(rp.TargetName, middlewares) mid := NewMiddlewareChain(rp.TargetName, middlewares)

View file

@ -0,0 +1,37 @@
package middleware
import (
"net/http"
"strconv"
"strings"
"testing"
. "github.com/yusing/go-proxy/internal/utils/testing"
)
type testPriority struct {
Value int `json:"value"`
}
var test = NewMiddleware[testPriority]()
func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool {
w.Header().Add("Test-Value", strconv.Itoa(t.Value))
return true
}
func TestMiddlewarePriority(t *testing.T) {
priorities := []int{4, 7, 9, 0}
chain := make([]*Middleware, len(priorities))
for i, p := range priorities {
mid, err := test.New(OptionsRaw{
"priority": p,
"value": i,
})
ExpectNoError(t, err)
chain[i] = mid
}
res, err := newMiddlewaresTest(chain, nil)
ExpectNoError(t, err)
ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2")
}

View file

@ -14,6 +14,8 @@ import (
var allMiddlewares = map[string]*Middleware{ var allMiddlewares = map[string]*Middleware{
"redirecthttp": RedirectHTTP, "redirecthttp": RedirectHTTP,
"auth": OIDC,
"request": ModifyRequest, "request": ModifyRequest,
"modifyrequest": ModifyRequest, "modifyrequest": ModifyRequest,
"response": ModifyResponse, "response": ModifyResponse,

View file

@ -0,0 +1,51 @@
package middleware
import (
"net/http"
"github.com/yusing/go-proxy/internal/api/v1/auth"
E "github.com/yusing/go-proxy/internal/error"
)
type oidcMiddleware struct {
oauth *auth.OIDCProvider
authMux *http.ServeMux
}
var OIDC = NewMiddleware[oidcMiddleware]()
const (
OIDCMiddlewareCallbackPath = "/godoxy-auth-oidc/callback"
OIDCLogoutPath = "/logout"
)
func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() {
return E.New("OIDC not enabled but Auth middleware is used")
}
provider, err := auth.NewOIDCProviderFromEnv(OIDCMiddlewareCallbackPath)
if err != nil {
return err
}
provider.SetOverrideHostEnabled(true)
amw.oauth = provider
amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, provider.OIDCCallbackHandler)
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
})
amw.authMux.HandleFunc("/", provider.RedirectOIDC)
return nil
}
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if err, _ := auth.CheckToken(w, r); err != nil {
amw.authMux.ServeHTTP(w, r)
return false
}
if r.URL.Path == OIDCLogoutPath {
auth.LogoutHandler(w, r)
return false
}
return true
}

View file

@ -127,6 +127,20 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
} }
args.setDefaults() args.setDefaults()
mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
return newMiddlewaresTest([]*Middleware{mid}, args)
}
func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) {
if args == nil {
args = new(testArgs)
}
args.setDefaults()
req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader()) req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader())
for k, v := range args.headers { for k, v := range args.headers {
req.Header[k] = v req.Header[k] = v
@ -139,14 +153,8 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E
rr.parent = http.DefaultTransport rr.parent = http.DefaultTransport
} }
rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr) rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr)
patchReverseProxy(rp, middlewares)
mid, setOptErr := middleware.New(args.middlewareOpt)
if setOptErr != nil {
return nil, setOptErr
}
patchReverseProxy(rp, []*Middleware{mid})
rp.ServeHTTP(w, req) rp.ServeHTTP(w, req)
resp := w.Result() resp := w.Result()

View file

@ -73,6 +73,16 @@ GoDoxy v0.8.2 expected changes
* Connection #0 to host localhost left intact * Connection #0 to host localhost left intact
``` ```
- **Thanks [polds](https://github.com/polds)**
Support WebUI authentication via OIDC by setting these environment variables:
- `GODOXY_API_USER`
- `GODOXY_API_JWT_SECRET`
- `GODOXY_OIDC_ISSUER_URL`
- `GODOXY_OIDC_CLIENT_ID`
- `GODOXY_OIDC_CLIENT_SECRET`
- `GODOXY_OIDC_REDIRECT_URL`
- `GODOXY_OIDC_SCOPES`
- Caddyfile like rules - Caddyfile like rules
```yaml ```yaml