diff --git a/.env.example b/.env.example index 19d7c24..b639486 100644 --- a/.env.example +++ b/.env.example @@ -2,6 +2,7 @@ TZ=ETC/UTC # generate secret with `openssl rand -base64 32` +# used for both user password authentication and OIDC GODOXY_API_JWT_SECRET= # the JWT token time-to-live @@ -11,6 +12,7 @@ GODOXY_API_JWT_TOKEN_TTL=1h # Important: If using OIDC authentication, the API_USER must match the username # provided by the OIDC provider. GODOXY_API_USER=admin +# Password is not required for OIDC authentication GODOXY_API_PASSWORD=password # OIDC Configuration (optional) diff --git a/Makefile b/Makefile index 94d0ce9..8c6c68c 100755 --- a/Makefile +++ b/Makefile @@ -39,7 +39,7 @@ profile: run: build sudo setcap CAP_NET_BIND_SERVICE=+eip bin/godoxy - bin/godoxy + [ -f .env ] && godotenv -f .env bin/godoxy || bin/godoxy mtrace: bin/godoxy debug-ls-mtrace > mtrace.json diff --git a/cmd/main.go b/cmd/main.go index c1a88eb..4239ed9 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -9,7 +9,6 @@ import ( "time" "github.com/yusing/go-proxy/internal" - "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -112,15 +111,6 @@ func main() { cfg.Start() config.WatchChanges() - if !auth.IsEnabled() { - logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") - } else { - // Initialize authentication providers - if err := auth.Initialize(); err != nil { - logging.Fatal().Err(err).Msg("Failed to initialize authentication providers") - } - } - sig := make(chan os.Signal, 1) signal.Notify(sig, syscall.SIGINT) signal.Notify(sig, syscall.SIGTERM) diff --git a/internal/api/handler.go b/internal/api/handler.go index cab7a1f..a1dc91d 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -23,8 +23,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler { mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler) - mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler) - mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) + mux.HandleFunc("GET", "/v1/auth/redirect", auth.APIAuthRedirectHandler) + mux.HandleFunc("GET", "/v1/auth/callback", auth.APIOIDCCallbackHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) diff --git a/internal/api/v1/auth/auth.go b/internal/api/v1/auth/auth.go index 532d339..49ba85a 100644 --- a/internal/api/v1/auth/auth.go +++ b/internal/api/v1/auth/auth.go @@ -2,6 +2,7 @@ package auth import ( "fmt" + "net" "net/http" "time" @@ -9,6 +10,7 @@ import ( U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/utils/strutils" ) @@ -23,29 +25,59 @@ type ( } ) -// Initialize sets up authentication providers. -func Initialize() error { +// init sets up authentication providers. +func init() { + if !IsEnabled() { + logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") + return + } // Initialize OIDC if configured. if common.OIDCIssuerURL != "" { - return InitOIDC( + if err := initOIDC( common.OIDCIssuerURL, common.OIDCClientID, common.OIDCClientSecret, common.OIDCRedirectURL, - ) + ); err != nil { + logging.Fatal().Err(err).Msg("failed to initialize OIDC provider") + } } - return nil } func IsEnabled() bool { - return common.APIJWTSecret != nil || common.OIDCIssuerURL != "" + return common.APIJWTSecret != nil || IsOIDCEnabled() } -// AuthRedirectHandler handles redirect to login page or OIDC login base on configuration. -func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) { +func IsOIDCEnabled() bool { + return common.OIDCIssuerURL != "" +} + +// cookieFQDN returns the fully qualified domain name of the request host +// with subdomain stripped. +// +// If the request host does not have a subdomain, +// an empty string is returned +// +// "abc.example.com" -> "example.com" +// "example.com" -> "" +func cookieFQDN(r *http.Request) string { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + parts := strutils.SplitRune(host, '.') + if len(parts) < 2 { + return "" + } + parts[0] = "" + return strutils.JoinRune(parts, '.') +} + +// APIAuthRedirectHandler handles API redirect to login page or OIDC login base on configuration. +func APIAuthRedirectHandler(w http.ResponseWriter, r *http.Request) { switch { - case oauthConfig != nil: - RedirectOIDC(w, r) + case apiOAuth != nil: + apiOAuth.RedirectOIDC(w, r) return case common.APIJWTSecret != nil: http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) @@ -55,7 +87,7 @@ func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) { } } -func setAuthenticatedCookie(w http.ResponseWriter, username string) error { +func setAuthenticatedCookie(w http.ResponseWriter, r *http.Request, username string) error { expiresAt := time.Now().Add(common.APIJWTTokenTTL) claim := &Claims{ Username: username, @@ -72,9 +104,10 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error { Name: CookieToken, Value: tokenStr, Expires: expiresAt, + Domain: cookieFQDN(r), HttpOnly: true, Secure: true, - SameSite: http.SameSiteStrictMode, + SameSite: http.SameSiteLaxMode, Path: "/", }) return nil @@ -84,20 +117,22 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error { func LogoutHandler(w http.ResponseWriter, r *http.Request) { http.SetCookie(w, &http.Cookie{ Name: CookieToken, - Value: "", - Expires: time.Unix(0, 0), + MaxAge: -1, + Domain: cookieFQDN(r), HttpOnly: true, Secure: true, - SameSite: http.SameSiteStrictMode, + SameSite: http.SameSiteLaxMode, Path: "/", }) - AuthRedirectHandler(w, r) + APIAuthRedirectHandler(w, r) } func RequireAuth(next http.HandlerFunc) http.HandlerFunc { if IsEnabled() { return func(w http.ResponseWriter, r *http.Request) { - if checkToken(w, r) { + if err := CheckToken(w, r); err != nil { + U.RespondError(w, err, http.StatusUnauthorized) + } else { next(w, r) } } @@ -105,11 +140,10 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc { return next } -func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { +func CheckToken(w http.ResponseWriter, r *http.Request) error { tokenCookie, err := r.Cookie(CookieToken) if err != nil { - U.RespondError(w, E.New("missing token"), http.StatusUnauthorized) - return false + return E.New("missing token") } var claims Claims token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { @@ -118,22 +152,17 @@ func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { } return common.APIJWTSecret, nil }) - - switch { - case err != nil: - break - case !token.Valid: - err = E.New("invalid token") - case claims.Username != common.APIUser: - err = E.New("username mismatch").Subject(claims.Username) - case claims.ExpiresAt.Before(time.Now()): - err = E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) - } - if err != nil { - U.RespondError(w, err, http.StatusForbidden) - return false + return err + } + switch { + case !token.Valid: + return E.New("invalid token") + case claims.Username != common.APIUser: + return E.New("username mismatch").Subject(claims.Username) + case claims.ExpiresAt.Before(time.Now()): + return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) } - return true + return nil } diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index 3b1886a..66ebeab 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -13,42 +13,66 @@ import ( "golang.org/x/oauth2" ) -var ( +type OIDCProvider struct { oauthConfig *oauth2.Config oidcProvider *oidc.Provider oidcVerifier *oidc.IDTokenVerifier + overrideHost bool +} + +var ( + apiOAuth *OIDCProvider + APIOIDCCallbackHandler http.HandlerFunc ) -// InitOIDC initializes the OIDC provider. -func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error { +// initOIDC initializes the OIDC provider. +func initOIDC(issuerURL, clientID, clientSecret, redirectURL string) (err error) { if issuerURL == "" { return nil // OIDC not configured } + apiOAuth, err = NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL) + APIOIDCCallbackHandler = apiOAuth.OIDCCallbackHandler + return +} + +func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OIDCProvider, error) { provider, err := oidc.NewProvider(context.Background(), issuerURL) if err != nil { - return fmt.Errorf("failed to initialize OIDC provider: %w", err) + return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) } - oidcProvider = provider - oidcVerifier = provider.Verifier(&oidc.Config{ - ClientID: clientID, - }) + return &OIDCProvider{ + oauthConfig: &oauth2.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + RedirectURL: redirectURL, + Endpoint: provider.Endpoint(), + Scopes: strutils.CommaSeperatedList(common.OIDCScopes), + }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ + ClientID: clientID, + }), + }, nil +} - oauthConfig = &oauth2.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - RedirectURL: redirectURL, - Endpoint: provider.Endpoint(), - Scopes: strutils.CommaSeperatedList(common.OIDCScopes), - } +func NewOIDCProviderFromEnv(redirectURL string) (*OIDCProvider, error) { + return NewOIDCProvider( + common.OIDCIssuerURL, + common.OIDCClientID, + common.OIDCClientSecret, + redirectURL, + ) +} - return nil +func (provider *OIDCProvider) SetOverrideHostEnabled(enabled bool) { + provider.overrideHost = enabled } // RedirectOIDC initiates the OIDC login flow. -func RedirectOIDC(w http.ResponseWriter, r *http.Request) { - if oauthConfig == nil { +func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Request) { + if provider == nil { U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) return } @@ -59,18 +83,29 @@ func RedirectOIDC(w http.ResponseWriter, r *http.Request) { Value: state, MaxAge: 300, HttpOnly: true, - SameSite: http.SameSiteNoneMode, + SameSite: http.SameSiteLaxMode, Secure: true, Path: "/", }) - url := oauthConfig.AuthCodeURL(state) - http.Redirect(w, r, url, http.StatusTemporaryRedirect) + redirURL := provider.oauthConfig.AuthCodeURL(state) + if provider.overrideHost { + u, err := r.URL.Parse(redirURL) + if err != nil { + U.HandleErr(w, r, err, http.StatusInternalServerError) + return + } + q := u.Query() + q.Set("redirect_uri", "https://"+r.Host+q.Get("redirect_uri")) + u.RawQuery = q.Encode() + redirURL = u.String() + } + http.Redirect(w, r, redirURL, http.StatusTemporaryRedirect) } // OIDCCallbackHandler handles the OIDC callback. -func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { - if oauthConfig == nil { +func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { + if provider == nil { U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) return } @@ -81,7 +116,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { return } - if oidcProvider == nil { + if provider.oidcProvider == nil { U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) return } @@ -98,7 +133,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { } code := r.URL.Query().Get("code") - oauth2Token, err := oauthConfig.Exchange(r.Context(), code) + oauth2Token, err := provider.oauthConfig.Exchange(r.Context(), code) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError) return @@ -110,7 +145,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { return } - idToken, err := oidcVerifier.Verify(r.Context(), rawIDToken) + idToken, err := provider.oidcVerifier.Verify(r.Context(), rawIDToken) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError) return @@ -125,7 +160,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { return } - if err := setAuthenticatedCookie(w, claims.Username); err != nil { + if err := setAuthenticatedCookie(w, r, claims.Username); err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return } @@ -148,7 +183,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) { } // Create test JWT token - if err := setAuthenticatedCookie(w, "test-user"); err != nil { + if err := setAuthenticatedCookie(w, r, "test-user"); err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return } diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index 1e1f986..76ed2d0 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -14,22 +14,22 @@ import ( func setupMockOIDC(t *testing.T) { t.Helper() - oauthConfig = &oauth2.Config{ - ClientID: "test-client", - ClientSecret: "test-secret", - RedirectURL: "http://localhost/callback", - Endpoint: oauth2.Endpoint{ - AuthURL: "http://mock-provider/auth", - TokenURL: "http://mock-provider/token", + apiOAuth = &OIDCProvider{ + oauthConfig: &oauth2.Config{ + ClientID: "test-client", + ClientSecret: "test-secret", + RedirectURL: "http://localhost/callback", + Endpoint: oauth2.Endpoint{ + AuthURL: "http://mock-provider/auth", + TokenURL: "http://mock-provider/token", + }, + Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, }, - Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, } } func cleanup() { - oauthConfig = nil - oidcProvider = nil - oidcVerifier = nil + apiOAuth = nil } func TestOIDCLoginHandler(t *testing.T) { @@ -65,13 +65,13 @@ func TestOIDCLoginHandler(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if !tt.configureOAuth { - oauthConfig = nil + apiOAuth = nil } req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) w := httptest.NewRecorder() - RedirectOIDC(w, req) + apiOAuth.RedirectOIDC(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) @@ -140,7 +140,7 @@ func TestOIDCCallbackHandler(t *testing.T) { } if !tt.configureOAuth { - oauthConfig = nil + apiOAuth = nil } req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil) @@ -152,7 +152,7 @@ func TestOIDCCallbackHandler(t *testing.T) { } w := httptest.NewRecorder() - OIDCCallbackHandler(w, req) + apiOAuth.OIDCCallbackHandler(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus) @@ -194,7 +194,7 @@ func TestInitOIDC(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Cleanup(cleanup) - err := InitOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL) + err := initOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL) if (err != nil) != tt.wantErr { t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/internal/api/v1/auth/userpass.go b/internal/api/v1/auth/userpass.go index 6d00e6a..6e3c4c6 100644 --- a/internal/api/v1/auth/userpass.go +++ b/internal/api/v1/auth/userpass.go @@ -37,7 +37,7 @@ func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) { U.HandleErr(w, r, err, http.StatusUnauthorized) return } - if err := setAuthenticatedCookie(w, creds.Username); err != nil { + if err := setAuthenticatedCookie(w, r, creds.Username); err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return } diff --git a/internal/entrypoint/entrypoint.go b/internal/entrypoint/entrypoint.go index a2f5c0a..faed489 100644 --- a/internal/entrypoint/entrypoint.go +++ b/internal/entrypoint/entrypoint.go @@ -89,7 +89,11 @@ func (ep *Entrypoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Then scraper / scanners will know the subdomain is invalid. // With StatusNotFound, they won't know whether it's the path, or the subdomain that is invalid. if served := middleware.ServeStaticErrorPageFile(w, r); !served { - logger.Err(err).Str("method", r.Method).Str("url", r.URL.String()).Msg("request") + logger.Err(err). + Str("method", r.Method). + Str("url", r.URL.String()). + Str("remote", r.RemoteAddr). + Msg("request") errorPage, ok := errorpage.GetErrorPageByStatus(http.StatusNotFound) if ok { w.WriteHeader(http.StatusNotFound) diff --git a/internal/net/http/middleware/middleware.go b/internal/net/http/middleware/middleware.go index 78410d3..a271972 100644 --- a/internal/net/http/middleware/middleware.go +++ b/internal/net/http/middleware/middleware.go @@ -4,6 +4,7 @@ import ( "encoding/json" "net/http" "reflect" + "sort" "strings" E "github.com/yusing/go-proxy/internal/error" @@ -26,28 +27,50 @@ type ( name string construct ImplNewFunc impl any + // priority is only applied for ReverseProxy. + // + // Middleware compose follows the order of the slice + // + // Default is 10, 0 is the highest + priority int } + ByPriority []*Middleware RequestModifier interface { before(w http.ResponseWriter, r *http.Request) (proceed bool) } - ResponseModifier interface{ modifyResponse(r *http.Response) error } - MiddlewareWithSetup interface{ setup() } - MiddlewareFinalizer interface{ finalize() } + ResponseModifier interface{ modifyResponse(r *http.Response) error } + MiddlewareWithSetup interface{ setup() } + MiddlewareFinalizer interface{ finalize() } + MiddlewareFinalizerWithError interface { + finalize() error + } MiddlewareWithTracer interface { enableTrace() getTracer() *Tracer } ) +const DefaultPriority = 10 + +func (m ByPriority) Len() int { return len(m) } +func (m ByPriority) Swap(i, j int) { m[i], m[j] = m[j], m[i] } +func (m ByPriority) Less(i, j int) bool { return m[i].priority < m[j].priority } + func NewMiddleware[ImplType any]() *Middleware { // type check - switch any(new(ImplType)).(type) { + t := any(new(ImplType)) + switch t.(type) { case RequestModifier: case ResponseModifier: default: panic("must implement RequestModifier or ResponseModifier") } + _, hasFinializer := t.(MiddlewareFinalizer) + _, hasFinializerWithError := t.(MiddlewareFinalizerWithError) + if hasFinializer && hasFinializerWithError { + panic("MiddlewareFinalizer and MiddlewareFinalizerWithError are mutually exclusive") + } return &Middleware{ name: strings.ToLower(reflect.TypeFor[ImplType]().Name()), construct: func() any { return new(ImplType) }, @@ -84,13 +107,29 @@ func (m *Middleware) apply(optsRaw OptionsRaw) E.Error { if len(optsRaw) == 0 { return nil } + priority, ok := optsRaw["priority"].(int) + if ok { + m.priority = priority + // remove priority for deserialization, restore later + delete(optsRaw, "priority") + defer func() { + optsRaw["priority"] = priority + }() + } else { + m.priority = DefaultPriority + } return utils.Deserialize(optsRaw, m.impl) } -func (m *Middleware) finalize() { +func (m *Middleware) finalize() error { if finalizer, ok := m.impl.(MiddlewareFinalizer); ok { finalizer.finalize() + return nil } + if finalizer, ok := m.impl.(MiddlewareFinalizerWithError); ok { + return finalizer.finalize() + } + return nil } func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { @@ -105,7 +144,9 @@ func (m *Middleware) New(optsRaw OptionsRaw) (*Middleware, E.Error) { if err := mid.apply(optsRaw); err != nil { return nil, err } - mid.finalize() + if err := mid.finalize(); err != nil { + return nil, E.From(err) + } return mid, nil } @@ -119,8 +160,9 @@ func (m *Middleware) String() string { func (m *Middleware) MarshalJSON() ([]byte, error) { return json.MarshalIndent(map[string]any{ - "name": m.name, - "options": m.impl, + "name": m.name, + "options": m.impl, + "priority": m.priority, }, "", " ") } @@ -193,6 +235,7 @@ func PatchReverseProxy(rp *ReverseProxy, middlewaresMap map[string]OptionsRaw) ( } func patchReverseProxy(rp *ReverseProxy, middlewares []*Middleware) { + sort.Sort(ByPriority(middlewares)) middlewares = append([]*Middleware{newSetUpstreamHeaders(rp)}, middlewares...) mid := NewMiddlewareChain(rp.TargetName, middlewares) diff --git a/internal/net/http/middleware/middleware_test.go b/internal/net/http/middleware/middleware_test.go new file mode 100644 index 0000000..5b6e521 --- /dev/null +++ b/internal/net/http/middleware/middleware_test.go @@ -0,0 +1,37 @@ +package middleware + +import ( + "net/http" + "strconv" + "strings" + "testing" + + . "github.com/yusing/go-proxy/internal/utils/testing" +) + +type testPriority struct { + Value int `json:"value"` +} + +var test = NewMiddleware[testPriority]() + +func (t testPriority) before(w http.ResponseWriter, r *http.Request) bool { + w.Header().Add("Test-Value", strconv.Itoa(t.Value)) + return true +} + +func TestMiddlewarePriority(t *testing.T) { + priorities := []int{4, 7, 9, 0} + chain := make([]*Middleware, len(priorities)) + for i, p := range priorities { + mid, err := test.New(OptionsRaw{ + "priority": p, + "value": i, + }) + ExpectNoError(t, err) + chain[i] = mid + } + res, err := newMiddlewaresTest(chain, nil) + ExpectNoError(t, err) + ExpectEqual(t, strings.Join(res.ResponseHeaders["Test-Value"], ","), "3,0,1,2") +} diff --git a/internal/net/http/middleware/middlewares.go b/internal/net/http/middleware/middlewares.go index cfea2c3..0477601 100644 --- a/internal/net/http/middleware/middlewares.go +++ b/internal/net/http/middleware/middlewares.go @@ -14,6 +14,8 @@ import ( var allMiddlewares = map[string]*Middleware{ "redirecthttp": RedirectHTTP, + "auth": OIDC, + "request": ModifyRequest, "modifyrequest": ModifyRequest, "response": ModifyResponse, diff --git a/internal/net/http/middleware/oidc.go b/internal/net/http/middleware/oidc.go new file mode 100644 index 0000000..5d23f25 --- /dev/null +++ b/internal/net/http/middleware/oidc.go @@ -0,0 +1,51 @@ +package middleware + +import ( + "net/http" + + "github.com/yusing/go-proxy/internal/api/v1/auth" + E "github.com/yusing/go-proxy/internal/error" +) + +type oidcMiddleware struct { + oauth *auth.OIDCProvider + authMux *http.ServeMux +} + +var OIDC = NewMiddleware[oidcMiddleware]() + +const ( + OIDCMiddlewareCallbackPath = "/godoxy-auth-oidc/callback" + OIDCLogoutPath = "/logout" +) + +func (amw *oidcMiddleware) finalize() error { + if !auth.IsOIDCEnabled() { + return E.New("OIDC not enabled but Auth middleware is used") + } + provider, err := auth.NewOIDCProviderFromEnv(OIDCMiddlewareCallbackPath) + if err != nil { + return err + } + provider.SetOverrideHostEnabled(true) + amw.oauth = provider + amw.authMux = http.NewServeMux() + amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, provider.OIDCCallbackHandler) + amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + }) + amw.authMux.HandleFunc("/", provider.RedirectOIDC) + return nil +} + +func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { + if err, _ := auth.CheckToken(w, r); err != nil { + amw.authMux.ServeHTTP(w, r) + return false + } + if r.URL.Path == OIDCLogoutPath { + auth.LogoutHandler(w, r) + return false + } + return true +} diff --git a/internal/net/http/middleware/test_utils.go b/internal/net/http/middleware/test_utils.go index dceeb39..9c8ce3a 100644 --- a/internal/net/http/middleware/test_utils.go +++ b/internal/net/http/middleware/test_utils.go @@ -127,6 +127,20 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E } args.setDefaults() + mid, setOptErr := middleware.New(args.middlewareOpt) + if setOptErr != nil { + return nil, setOptErr + } + + return newMiddlewaresTest([]*Middleware{mid}, args) +} + +func newMiddlewaresTest(middlewares []*Middleware, args *testArgs) (*TestResult, E.Error) { + if args == nil { + args = new(testArgs) + } + args.setDefaults() + req := httptest.NewRequest(args.reqMethod, args.reqURL.String(), args.bodyReader()) for k, v := range args.headers { req.Header[k] = v @@ -139,14 +153,8 @@ func newMiddlewareTest(middleware *Middleware, args *testArgs) (*TestResult, E.E rr.parent = http.DefaultTransport } - rp := reverseproxy.NewReverseProxy(middleware.name, args.upstreamURL, rr) - - mid, setOptErr := middleware.New(args.middlewareOpt) - if setOptErr != nil { - return nil, setOptErr - } - - patchReverseProxy(rp, []*Middleware{mid}) + rp := reverseproxy.NewReverseProxy("test", args.upstreamURL, rr) + patchReverseProxy(rp, middlewares) rp.ServeHTTP(w, req) resp := w.Result() diff --git a/next-release.md b/next-release.md index ae1989a..515ec21 100644 --- a/next-release.md +++ b/next-release.md @@ -73,6 +73,16 @@ GoDoxy v0.8.2 expected changes * Connection #0 to host localhost left intact ``` +- **Thanks [polds](https://github.com/polds)** + Support WebUI authentication via OIDC by setting these environment variables: + - `GODOXY_API_USER` + - `GODOXY_API_JWT_SECRET` + - `GODOXY_OIDC_ISSUER_URL` + - `GODOXY_OIDC_CLIENT_ID` + - `GODOXY_OIDC_CLIENT_SECRET` + - `GODOXY_OIDC_REDIRECT_URL` + - `GODOXY_OIDC_SCOPES` + - Caddyfile like rules ```yaml