diff --git a/.env.example b/.env.example index b639486..7588c4a 100644 --- a/.env.example +++ b/.env.example @@ -1,20 +1,15 @@ # set timezone to get correct log timestamp TZ=ETC/UTC +# API/WebUI user password login credentials (optional) +# These fields are not required for OIDC authentication +GODOXY_API_USER=admin +GODOXY_API_PASSWORD=password # generate secret with `openssl rand -base64 32` -# used for both user password authentication and OIDC GODOXY_API_JWT_SECRET= - # the JWT token time-to-live GODOXY_API_JWT_TOKEN_TTL=1h -# API/WebUI login credentials -# Important: If using OIDC authentication, the API_USER must match the username -# provided by the OIDC provider. -GODOXY_API_USER=admin -# Password is not required for OIDC authentication -GODOXY_API_PASSWORD=password - # OIDC Configuration (optional) # Uncomment and configure these values to enable OIDC authentication. # GODOXY_OIDC_ISSUER_URL=https://accounts.google.com @@ -24,6 +19,8 @@ GODOXY_API_PASSWORD=password # GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback # Comma-separated list of scopes # GODOXY_OIDC_SCOPES=openid, profile, email +# Comma-separated list of allowed users +# GODOXY_OIDC_ALLOWED_USERS=user1,user2 # Proxy listening address GODOXY_HTTP_ADDR=:80 diff --git a/cmd/main.go b/cmd/main.go index 4239ed9..07981ef 100755 --- a/cmd/main.go +++ b/cmd/main.go @@ -9,6 +9,7 @@ import ( "time" "github.com/yusing/go-proxy/internal" + "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/config" @@ -108,6 +109,10 @@ func main() { return } + if err := auth.Initialize(); err != nil { + logging.Fatal().Err(err).Msg("failed to initialize authentication") + } + cfg.Start() config.WatchChanges() diff --git a/internal/api/handler.go b/internal/api/handler.go index a1dc91d..7888fc2 100644 --- a/internal/api/handler.go +++ b/internal/api/handler.go @@ -1,43 +1,44 @@ package api import ( - "net" "net/http" v1 "github.com/yusing/go-proxy/internal/api/v1" "github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/favicon" - . "github.com/yusing/go-proxy/internal/api/v1/utils" - "github.com/yusing/go-proxy/internal/common" config "github.com/yusing/go-proxy/internal/config/types" + "github.com/yusing/go-proxy/internal/utils/strutils" ) type ServeMux struct{ *http.ServeMux } -func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { - mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler)) +func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) { + for _, m := range strutils.CommaSeperatedList(methods) { + mux.ServeMux.HandleFunc(m+" "+endpoint, handler) + } } func NewHandler(cfg config.ConfigInstance) http.Handler { mux := ServeMux{http.NewServeMux()} mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1/version", v1.GetVersion) - mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler) - mux.HandleFunc("GET", "/v1/auth/redirect", auth.APIAuthRedirectHandler) - mux.HandleFunc("GET", "/v1/auth/callback", auth.APIOIDCCallbackHandler) - mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) - mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) - mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) - mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) + mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile) mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon)) + + defaultAuth := auth.GetDefaultAuth() + if defaultAuth != nil { + mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage) + mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler) + mux.HandleFunc("GET,POST", "/v1/auth/logout", auth.LogoutCallbackHandler(defaultAuth)) + } return mux } @@ -46,20 +47,3 @@ func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w handler(cfg, w, r) } } - -// allow only requests to API server with localhost. -func checkHost(f http.HandlerFunc) http.HandlerFunc { - if common.IsDebug { - return f - } - return func(w http.ResponseWriter, r *http.Request) { - host, _, _ := net.SplitHostPort(r.RemoteAddr) - if host != "127.0.0.1" && host != "localhost" && host != "[::1]" { - LogWarn(r).Msgf("blocked API request from %s", host) - http.Error(w, "forbidden", http.StatusForbidden) - return - } - LogDebug(r).Interface("headers", r.Header).Msg("API request") - f(w, r) - } -} diff --git a/internal/api/v1/auth/auth.go b/internal/api/v1/auth/auth.go index 49ba85a..1a32bab 100644 --- a/internal/api/v1/auth/auth.go +++ b/internal/api/v1/auth/auth.go @@ -1,47 +1,35 @@ package auth import ( - "fmt" - "net" "net/http" - "time" - "github.com/golang-jwt/jwt/v5" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" - E "github.com/yusing/go-proxy/internal/error" "github.com/yusing/go-proxy/internal/logging" - "github.com/yusing/go-proxy/internal/utils/strutils" ) -type ( - Credentials struct { - Username string `json:"username"` - Password string `json:"password"` - } - Claims struct { - Username string `json:"username"` - jwt.RegisteredClaims - } -) +var defaultAuth Provider -// init sets up authentication providers. -func init() { +// Initialize sets up authentication providers. +func Initialize() error { if !IsEnabled() { logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") - return + return nil } + + var err error // Initialize OIDC if configured. if common.OIDCIssuerURL != "" { - if err := initOIDC( - common.OIDCIssuerURL, - common.OIDCClientID, - common.OIDCClientSecret, - common.OIDCRedirectURL, - ); err != nil { - logging.Fatal().Err(err).Msg("failed to initialize OIDC provider") - } + defaultAuth, err = NewOIDCProviderFromEnv() + } else { + defaultAuth, err = NewUserPassAuthFromEnv() } + + return err +} + +func GetDefaultAuth() Provider { + return defaultAuth } func IsEnabled() bool { @@ -52,85 +40,10 @@ func IsOIDCEnabled() bool { return common.OIDCIssuerURL != "" } -// cookieFQDN returns the fully qualified domain name of the request host -// with subdomain stripped. -// -// If the request host does not have a subdomain, -// an empty string is returned -// -// "abc.example.com" -> "example.com" -// "example.com" -> "" -func cookieFQDN(r *http.Request) string { - host, _, err := net.SplitHostPort(r.Host) - if err != nil { - host = r.Host - } - parts := strutils.SplitRune(host, '.') - if len(parts) < 2 { - return "" - } - parts[0] = "" - return strutils.JoinRune(parts, '.') -} - -// APIAuthRedirectHandler handles API redirect to login page or OIDC login base on configuration. -func APIAuthRedirectHandler(w http.ResponseWriter, r *http.Request) { - switch { - case apiOAuth != nil: - apiOAuth.RedirectOIDC(w, r) - return - case common.APIJWTSecret != nil: - http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) - return - default: - http.Redirect(w, r, "/", http.StatusTemporaryRedirect) - } -} - -func setAuthenticatedCookie(w http.ResponseWriter, r *http.Request, username string) error { - expiresAt := time.Now().Add(common.APIJWTTokenTTL) - claim := &Claims{ - Username: username, - RegisteredClaims: jwt.RegisteredClaims{ - ExpiresAt: jwt.NewNumericDate(expiresAt), - }, - } - token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim) - tokenStr, err := token.SignedString(common.APIJWTSecret) - if err != nil { - return err - } - http.SetCookie(w, &http.Cookie{ - Name: CookieToken, - Value: tokenStr, - Expires: expiresAt, - Domain: cookieFQDN(r), - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteLaxMode, - Path: "/", - }) - return nil -} - -// LogoutHandler clear authentication cookie and redirect to login page. -func LogoutHandler(w http.ResponseWriter, r *http.Request) { - http.SetCookie(w, &http.Cookie{ - Name: CookieToken, - MaxAge: -1, - Domain: cookieFQDN(r), - HttpOnly: true, - Secure: true, - SameSite: http.SameSiteLaxMode, - Path: "/", - }) - APIAuthRedirectHandler(w, r) -} - func RequireAuth(next http.HandlerFunc) http.HandlerFunc { if IsEnabled() { return func(w http.ResponseWriter, r *http.Request) { - if err := CheckToken(w, r); err != nil { + if err := defaultAuth.CheckToken(w, r); err != nil { U.RespondError(w, err, http.StatusUnauthorized) } else { next(w, r) @@ -139,30 +52,3 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc { } return next } - -func CheckToken(w http.ResponseWriter, r *http.Request) error { - tokenCookie, err := r.Cookie(CookieToken) - if err != nil { - return E.New("missing token") - } - var claims Claims - token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) - } - return common.APIJWTSecret, nil - }) - if err != nil { - return err - } - switch { - case !token.Valid: - return E.New("invalid token") - case claims.Username != common.APIUser: - return E.New("username mismatch").Subject(claims.Username) - case claims.ExpiresAt.Before(time.Now()): - return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) - } - - return nil -} diff --git a/internal/api/v1/auth/cookies.go b/internal/api/v1/auth/cookies.go deleted file mode 100644 index c6d7386..0000000 --- a/internal/api/v1/auth/cookies.go +++ /dev/null @@ -1,6 +0,0 @@ -package auth - -const ( - CookieToken = "godoxy_token" - CookieOauthState = "godoxy_oauth_state" -) diff --git a/internal/api/v1/auth/oidc.go b/internal/api/v1/auth/oidc.go index 66ebeab..25b5b82 100644 --- a/internal/api/v1/auth/oidc.go +++ b/internal/api/v1/auth/oidc.go @@ -2,8 +2,13 @@ package auth import ( "context" + "crypto/rand" + "encoding/base64" + "errors" "fmt" "net/http" + "slices" + "time" "github.com/coreos/go-oidc/v3/oidc" U "github.com/yusing/go-proxy/internal/api/v1/utils" @@ -17,26 +22,17 @@ type OIDCProvider struct { oauthConfig *oauth2.Config oidcProvider *oidc.Provider oidcVerifier *oidc.IDTokenVerifier + allowedUsers []string overrideHost bool } -var ( - apiOAuth *OIDCProvider - APIOIDCCallbackHandler http.HandlerFunc -) +const CookieOauthState = "godoxy_oidc_state" -// initOIDC initializes the OIDC provider. -func initOIDC(issuerURL, clientID, clientSecret, redirectURL string) (err error) { - if issuerURL == "" { - return nil // OIDC not configured +func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) { + if len(allowedUsers) == 0 { + return nil, errors.New("OIDC allowed users must not be empty") } - apiOAuth, err = NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL) - APIOIDCCallbackHandler = apiOAuth.OIDCCallbackHandler - return -} - -func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OIDCProvider, error) { provider, err := oidc.NewProvider(context.Background(), issuerURL) if err != nil { return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) @@ -54,30 +50,82 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OI oidcVerifier: provider.Verifier(&oidc.Config{ ClientID: clientID, }), + allowedUsers: allowedUsers, }, nil } -func NewOIDCProviderFromEnv(redirectURL string) (*OIDCProvider, error) { +// NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables. +func NewOIDCProviderFromEnv() (*OIDCProvider, error) { return NewOIDCProvider( common.OIDCIssuerURL, common.OIDCClientID, common.OIDCClientSecret, - redirectURL, + common.OIDCRedirectURL, + common.OIDCAllowedUsers, ) } -func (provider *OIDCProvider) SetOverrideHostEnabled(enabled bool) { - provider.overrideHost = enabled +func (auth *OIDCProvider) TokenCookieName() string { + return "godoxy_oidc_token" +} + +func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) { + auth.overrideHost = enabled +} + +func (auth *OIDCProvider) SetAllowedUsers(users []string) { + auth.allowedUsers = users +} + +func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error { + token, err := r.Cookie(auth.TokenCookieName()) + if err != nil { + return ErrMissingToken + } + + // checks for Expiry, Audience == ClientID, Issuer, etc. + idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value) + if err != nil { + return fmt.Errorf("failed to verify ID token: %w", err) + } + + if len(idToken.Audience) == 0 { + return ErrInvalidToken + } + + var claims struct { + Email string `json:"email"` + Username string `json:"preferred_username"` + } + if err := idToken.Claims(&claims); err != nil { + return fmt.Errorf("failed to parse claims: %w", err) + } + + if !slices.Contains(auth.allowedUsers, claims.Username) { + return ErrUserNotAllowed.Subject(claims.Username) + } + return nil +} + +// generateState generates a random string for ODIC state. +const odicStateLength = 32 + +func generateState() (string, error) { + b := make([]byte, odicStateLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b)[:odicStateLength], nil } // RedirectOIDC initiates the OIDC login flow. -func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Request) { - if provider == nil { - U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) +func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { + state, err := generateState() + if err != nil { + U.HandleErr(w, r, err, http.StatusInternalServerError) return } - - state := common.GenerateRandomString(32) http.SetCookie(w, &http.Cookie{ Name: CookieOauthState, Value: state, @@ -88,8 +136,8 @@ func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Reques Path: "/", }) - redirURL := provider.oauthConfig.AuthCodeURL(state) - if provider.overrideHost { + redirURL := auth.oauthConfig.AuthCodeURL(state) + if auth.overrideHost { u, err := r.URL.Parse(redirURL) if err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) @@ -104,20 +152,10 @@ func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Reques } // OIDCCallbackHandler handles the OIDC callback. -func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { - if provider == nil { - U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) - return - } - +func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) { // For testing purposes, skip provider verification if common.IsTest { - handleTestCallback(w, r) - return - } - - if provider.oidcProvider == nil { - U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) + auth.handleTestCallback(w, r) return } @@ -133,7 +171,7 @@ func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http } code := r.URL.Query().Get("code") - oauth2Token, err := provider.oauthConfig.Exchange(r.Context(), code) + oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError) return @@ -145,32 +183,20 @@ func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http return } - idToken, err := provider.oidcVerifier.Verify(r.Context(), rawIDToken) + idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken) if err != nil { U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError) return } - var claims struct { - Email string `json:"email"` - Username string `json:"preferred_username"` - } - if err := idToken.Claims(&claims); err != nil { - U.HandleErr(w, r, fmt.Errorf("failed to parse claims: %w", err), http.StatusInternalServerError) - return - } - - if err := setAuthenticatedCookie(w, r, claims.Username); err != nil { - U.HandleErr(w, r, err, http.StatusInternalServerError) - return - } + setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry)) // Redirect to home page http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } // handleTestCallback handles OIDC callback in test environment. -func handleTestCallback(w http.ResponseWriter, r *http.Request) { +func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) { state, err := r.Cookie(CookieOauthState) if err != nil { U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) @@ -183,10 +209,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) { } // Create test JWT token - if err := setAuthenticatedCookie(w, r, "test-user"); err != nil { - U.HandleErr(w, r, err, http.StatusInternalServerError) - return - } + setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour) http.Redirect(w, r, "/", http.StatusTemporaryRedirect) } diff --git a/internal/api/v1/auth/oidc_test.go b/internal/api/v1/auth/oidc_test.go index 76ed2d0..cd29acd 100644 --- a/internal/api/v1/auth/oidc_test.go +++ b/internal/api/v1/auth/oidc_test.go @@ -1,6 +1,7 @@ package auth import ( + "context" "net/http" "net/http/httptest" "testing" @@ -14,7 +15,8 @@ import ( func setupMockOIDC(t *testing.T) { t.Helper() - apiOAuth = &OIDCProvider{ + provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO()) + defaultAuth = &OIDCProvider{ oauthConfig: &oauth2.Config{ ClientID: "test-client", ClientSecret: "test-secret", @@ -25,53 +27,42 @@ func setupMockOIDC(t *testing.T) { }, Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, }, + oidcProvider: provider, + oidcVerifier: provider.Verifier(&oidc.Config{ + ClientID: "test-client", + }), + allowedUsers: []string{"test-user"}, } } func cleanup() { - apiOAuth = nil + defaultAuth = nil } func TestOIDCLoginHandler(t *testing.T) { // Setup common.APIJWTSecret = []byte("test-secret") - common.IsTest = true - t.Cleanup(func() { - cleanup() - common.IsTest = false - }) + t.Cleanup(cleanup) setupMockOIDC(t) tests := []struct { - name string - configureOAuth bool - wantStatus int - wantRedirect bool + name string + wantStatus int + wantRedirect bool }{ { - name: "Success - Redirects to provider", - configureOAuth: true, - wantStatus: http.StatusTemporaryRedirect, - wantRedirect: true, - }, - { - name: "Failure - OIDC not configured", - configureOAuth: false, - wantStatus: http.StatusNotImplemented, - wantRedirect: false, + name: "Success - Redirects to provider", + wantStatus: http.StatusTemporaryRedirect, + wantRedirect: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if !tt.configureOAuth { - apiOAuth = nil - } - req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) w := httptest.NewRecorder() - apiOAuth.RedirectOIDC(w, req) + defaultAuth.RedirectLoginPage(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) @@ -94,65 +85,45 @@ func TestOIDCLoginHandler(t *testing.T) { func TestOIDCCallbackHandler(t *testing.T) { // Setup common.APIJWTSecret = []byte("test-secret") - common.IsTest = true - t.Cleanup(func() { - cleanup() - common.IsTest = false - }) + t.Cleanup(cleanup) tests := []struct { - name string - configureOAuth bool - state string - code string - setupMocks func() - wantStatus int + name string + state string + code string + setupMocks bool + wantStatus int }{ { - name: "Success - Valid callback", - configureOAuth: true, - state: "valid-state", - code: "valid-code", - setupMocks: func() { - setupMockOIDC(t) - }, + name: "Success - Valid callback", + state: "valid-state", + code: "valid-code", + setupMocks: true, wantStatus: http.StatusTemporaryRedirect, }, { - name: "Failure - OIDC not configured", - configureOAuth: false, - wantStatus: http.StatusNotImplemented, - }, - { - name: "Failure - Missing state", - configureOAuth: true, - code: "valid-code", - setupMocks: func() { - setupMockOIDC(t) - }, + name: "Failure - Missing state", + code: "valid-code", + setupMocks: true, wantStatus: http.StatusBadRequest, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if tt.setupMocks != nil { - tt.setupMocks() - } - - if !tt.configureOAuth { - apiOAuth = nil + if tt.setupMocks { + setupMockOIDC(t) } req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil) if tt.state != "" { req.AddCookie(&http.Cookie{ - Name: "oauth_state", + Name: CookieOauthState, Value: tt.state, }) } w := httptest.NewRecorder() - apiOAuth.OIDCCallbackHandler(w, req) + defaultAuth.LoginCallbackHandler(w, req) if got := w.Code; got != tt.wantStatus { t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus) @@ -169,32 +140,48 @@ func TestOIDCCallbackHandler(t *testing.T) { } func TestInitOIDC(t *testing.T) { - common.IsTest = true - t.Cleanup(func() { - common.IsTest = false - }) tests := []struct { name string issuerURL string clientID string clientSecret string redirectURL string + allowedUsers []string wantErr bool }{ { - name: "Success - Empty configuration", + name: "Fail - Empty configuration", issuerURL: "", clientID: "", clientSecret: "", redirectURL: "", - wantErr: false, + allowedUsers: nil, + wantErr: true, + }, + // { + // name: "Success - Valid configuration", + // issuerURL: "https://example.com", + // clientID: "client_id", + // clientSecret: "client_secret", + // redirectURL: "https://example.com/callback", + // allowedUsers: []string{"user1", "user2"}, + // wantErr: false, + // }, + { + name: "Fail - No allowed users", + issuerURL: "https://example.com", + clientID: "client_id", + clientSecret: "client_secret", + redirectURL: "https://example.com/callback", + allowedUsers: []string{}, + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Cleanup(cleanup) - err := initOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL) + _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers) if (err != nil) != tt.wantErr { t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/internal/api/v1/auth/provider.go b/internal/api/v1/auth/provider.go new file mode 100644 index 0000000..69f24b9 --- /dev/null +++ b/internal/api/v1/auth/provider.go @@ -0,0 +1,12 @@ +package auth + +import ( + "net/http" +) + +type Provider interface { + TokenCookieName() string + CheckToken(w http.ResponseWriter, r *http.Request) error + RedirectLoginPage(w http.ResponseWriter, r *http.Request) + LoginCallbackHandler(w http.ResponseWriter, r *http.Request) +} diff --git a/internal/api/v1/auth/userpass.go b/internal/api/v1/auth/userpass.go index 6e3c4c6..05b2887 100644 --- a/internal/api/v1/auth/userpass.go +++ b/internal/api/v1/auth/userpass.go @@ -1,13 +1,17 @@ package auth import ( - "bytes" "encoding/json" + "fmt" "net/http" + "time" + "github.com/golang-jwt/jwt/v5" U "github.com/yusing/go-proxy/internal/api/v1/utils" "github.com/yusing/go-proxy/internal/common" E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" + "golang.org/x/crypto/bcrypt" ) var ( @@ -15,31 +19,120 @@ var ( ErrInvalidPassword = E.New("invalid password") ) -func validatePassword(cred *Credentials) error { - if cred.Username != common.APIUser { - return ErrInvalidUsername.Subject(cred.Username) +type ( + UserPassAuth struct { + username string + pwdHash []byte + secret []byte + tokenTTL time.Duration } - if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { - return ErrInvalidPassword.Subject(cred.Password) + UserPassClaims struct { + Username string `json:"username"` + jwt.RegisteredClaims } +) + +func NewUserPassAuth(username, password string, secret []byte, tokenTTL time.Duration) (*UserPassAuth, error) { + hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + return &UserPassAuth{ + username: username, + pwdHash: hash, + secret: secret, + tokenTTL: tokenTTL, + }, nil +} + +func NewUserPassAuthFromEnv() (*UserPassAuth, error) { + return NewUserPassAuth( + common.APIUser, + common.APIPassword, + common.APIJWTSecret, + common.APIJWTTokenTTL, + ) +} + +func (auth *UserPassAuth) TokenCookieName() string { + return "godoxy_token" +} + +func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) { + claim := &UserPassClaims{ + Username: auth.username, + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(time.Now().Add(auth.tokenTTL)), + }, + } + tok := jwt.NewWithClaims(jwt.SigningMethodHS512, claim) + token, err = tok.SignedString(auth.secret) + if err != nil { + return "", err + } + return token, nil +} + +func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error { + jwtCookie, err := r.Cookie(auth.TokenCookieName()) + if err != nil { + return ErrMissingToken + } + var claims UserPassClaims + token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"]) + } + return auth.secret, nil + }) + if err != nil { + return err + } + switch { + case !token.Valid: + return ErrInvalidToken + case claims.Username != auth.username: + return ErrUserNotAllowed.Subject(claims.Username) + case claims.ExpiresAt.Before(time.Now()): + return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time)) + } + return nil } -// UserPassLoginHandler handles user login. -func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) { - var creds Credentials +func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/login", http.StatusTemporaryRedirect) +} + +func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) { + var creds struct { + User string `json:"username"` + Pass string `json:"password"` + } err := json.NewDecoder(r.Body).Decode(&creds) if err != nil { U.HandleErr(w, r, err, http.StatusBadRequest) return } - if err := validatePassword(&creds); err != nil { + if err := auth.validatePassword(creds.User, creds.Pass); err != nil { U.HandleErr(w, r, err, http.StatusUnauthorized) return } - if err := setAuthenticatedCookie(w, r, creds.Username); err != nil { + token, err := auth.CreateToken(w, r) + if err != nil { U.HandleErr(w, r, err, http.StatusInternalServerError) return } + setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL) w.WriteHeader(http.StatusOK) } + +func (auth *UserPassAuth) validatePassword(user, pass string) error { + if user != auth.username { + return ErrInvalidUsername.Subject(user) + } + if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil { + return ErrInvalidPassword.Subject(pass) + } + return nil +} diff --git a/internal/api/v1/auth/utils.go b/internal/api/v1/auth/utils.go new file mode 100644 index 0000000..1d57de1 --- /dev/null +++ b/internal/api/v1/auth/utils.go @@ -0,0 +1,70 @@ +package auth + +import ( + "net" + "net/http" + "time" + + E "github.com/yusing/go-proxy/internal/error" + "github.com/yusing/go-proxy/internal/utils/strutils" +) + +var ( + ErrMissingToken = E.New("missing token") + ErrInvalidToken = E.New("invalid token") + ErrUserNotAllowed = E.New("user not allowed") +) + +// cookieFQDN returns the fully qualified domain name of the request host +// with subdomain stripped. +// +// If the request host does not have a subdomain, +// an empty string is returned +// +// "abc.example.com" -> "example.com" +// "example.com" -> "" +func cookieFQDN(r *http.Request) string { + host, _, err := net.SplitHostPort(r.Host) + if err != nil { + host = r.Host + } + parts := strutils.SplitRune(host, '.') + if len(parts) < 2 { + return "" + } + parts[0] = "" + return strutils.JoinRune(parts, '.') +} + +func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: value, + MaxAge: int(ttl.Seconds()), + Domain: cookieFQDN(r), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + Path: "/", + }) +} + +func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) { + http.SetCookie(w, &http.Cookie{ + Name: name, + Value: "", + MaxAge: -1, + Domain: cookieFQDN(r), + HttpOnly: true, + Secure: true, + SameSite: http.SameSiteLaxMode, + Path: "/", + }) +} + +func LogoutCallbackHandler(auth Provider) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + clearTokenCookie(w, r, auth.TokenCookieName()) + auth.RedirectLoginPage(w, r) + } +} diff --git a/internal/api/v1/utils/logging.go b/internal/api/v1/utils/logging.go index ac795b8..194735f 100644 --- a/internal/api/v1/utils/logging.go +++ b/internal/api/v1/utils/logging.go @@ -11,6 +11,7 @@ func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event { return logging.WithLevel(level). Str("module", "api"). Str("remote", r.RemoteAddr). + Str("host", r.Host). Str("uri", r.Method+" "+r.RequestURI) } diff --git a/internal/common/crypto.go b/internal/common/crypto.go index 5afd0bf..6214a57 100644 --- a/internal/common/crypto.go +++ b/internal/common/crypto.go @@ -1,18 +1,11 @@ package common import ( - "crypto/sha512" "encoding/base64" "github.com/rs/zerolog/log" ) -func HashPassword(pwd string) []byte { - h := sha512.New() - h.Write([]byte(pwd)) - return h.Sum(nil) -} - func decodeJWTKey(key string) []byte { if key == "" { return nil diff --git a/internal/common/env.go b/internal/common/env.go index b8f9a09..3d207d7 100644 --- a/internal/common/env.go +++ b/internal/common/env.go @@ -9,6 +9,7 @@ import ( "time" "github.com/rs/zerolog/log" + "github.com/yusing/go-proxy/internal/utils/strutils" ) var ( @@ -40,10 +41,10 @@ var ( MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http") PrometheusEnabled = MetricsHTTPURL != "" - APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) - APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) - APIUser = GetEnvString("API_USER", "admin") - APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password")) + APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) + APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) + APIUser = GetEnvString("API_USER", "admin") + APIPassword = GetEnvString("API_PASSWORD", "password") // OIDC Configuration. OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") @@ -51,6 +52,7 @@ var ( OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") + OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "") ) func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T { @@ -102,3 +104,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str func GetDurationEnv(key string, defaultValue time.Duration) time.Duration { return GetEnv(key, defaultValue, time.ParseDuration) } + +func GetCommaSepEnv(key string, defaultValue string) []string { + return strutils.CommaSeperatedList(GetEnvString(key, defaultValue)) +} diff --git a/internal/common/random.go b/internal/common/random.go deleted file mode 100644 index ea4586f..0000000 --- a/internal/common/random.go +++ /dev/null @@ -1,13 +0,0 @@ -package common - -import ( - "crypto/rand" - "encoding/base64" -) - -// GenerateRandomString generates a random string of specified length. -func GenerateRandomString(length int) string { - b := make([]byte, length) - rand.Read(b) - return base64.URLEncoding.EncodeToString(b)[:length] -} diff --git a/internal/net/http/middleware/oidc.go b/internal/net/http/middleware/oidc.go index 5d23f25..968b6b0 100644 --- a/internal/net/http/middleware/oidc.go +++ b/internal/net/http/middleware/oidc.go @@ -8,43 +8,47 @@ import ( ) type oidcMiddleware struct { - oauth *auth.OIDCProvider - authMux *http.ServeMux + AllowedUsers []string + + auth auth.Provider + authMux *http.ServeMux + logoutHandler http.HandlerFunc } var OIDC = NewMiddleware[oidcMiddleware]() const ( - OIDCMiddlewareCallbackPath = "/godoxy-auth-oidc/callback" - OIDCLogoutPath = "/logout" + OIDCMiddlewareCallbackPath = "/auth/callback" + OIDCLogoutPath = "/auth/logout" ) func (amw *oidcMiddleware) finalize() error { if !auth.IsOIDCEnabled() { return E.New("OIDC not enabled but Auth middleware is used") } - provider, err := auth.NewOIDCProviderFromEnv(OIDCMiddlewareCallbackPath) + authProvider, err := auth.NewOIDCProviderFromEnv() if err != nil { return err } - provider.SetOverrideHostEnabled(true) - amw.oauth = provider + authProvider.SetOverrideHostEnabled(true) amw.authMux = http.NewServeMux() - amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, provider.OIDCCallbackHandler) + amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler) amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { http.Error(w, "Unauthorized", http.StatusUnauthorized) }) - amw.authMux.HandleFunc("/", provider.RedirectOIDC) + amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage) + amw.logoutHandler = auth.LogoutCallbackHandler(authProvider) + amw.auth = authProvider return nil } func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { - if err, _ := auth.CheckToken(w, r); err != nil { + if err := amw.auth.CheckToken(w, r); err != nil { amw.authMux.ServeHTTP(w, r) return false } if r.URL.Path == OIDCLogoutPath { - auth.LogoutHandler(w, r) + amw.logoutHandler(w, r) return false } return true diff --git a/internal/utils/strutils/string.go b/internal/utils/strutils/string.go index 4664c2b..18f78c6 100644 --- a/internal/utils/strutils/string.go +++ b/internal/utils/strutils/string.go @@ -10,6 +10,9 @@ import ( // CommaSeperatedList returns a list of strings split by commas, // then trim spaces from each element. func CommaSeperatedList(s string) []string { + if s == "" { + return []string{} + } res := SplitComma(s) for i, part := range res { res[i] = strings.TrimSpace(part) diff --git a/next-release.md b/next-release.md index 515ec21..29c0660 100644 --- a/next-release.md +++ b/next-release.md @@ -75,13 +75,12 @@ GoDoxy v0.8.2 expected changes - **Thanks [polds](https://github.com/polds)** Support WebUI authentication via OIDC by setting these environment variables: - - `GODOXY_API_USER` - - `GODOXY_API_JWT_SECRET` - `GODOXY_OIDC_ISSUER_URL` - `GODOXY_OIDC_CLIENT_ID` - `GODOXY_OIDC_CLIENT_SECRET` - `GODOXY_OIDC_REDIRECT_URL` - - `GODOXY_OIDC_SCOPES` + - `GODOXY_OIDC_SCOPES` _(optional)_ + - `GODOXY_OIDC_ALLOWED_USERS` - Caddyfile like rules