auth code cleanup

This commit is contained in:
yusing 2025-01-14 04:05:33 +08:00
parent bb0ee5d7a9
commit 9aee310844
17 changed files with 394 additions and 350 deletions

View file

@ -1,20 +1,15 @@
# set timezone to get correct log timestamp # set timezone to get correct log timestamp
TZ=ETC/UTC TZ=ETC/UTC
# API/WebUI user password login credentials (optional)
# These fields are not required for OIDC authentication
GODOXY_API_USER=admin
GODOXY_API_PASSWORD=password
# generate secret with `openssl rand -base64 32` # generate secret with `openssl rand -base64 32`
# used for both user password authentication and OIDC
GODOXY_API_JWT_SECRET= GODOXY_API_JWT_SECRET=
# the JWT token time-to-live # the JWT token time-to-live
GODOXY_API_JWT_TOKEN_TTL=1h GODOXY_API_JWT_TOKEN_TTL=1h
# API/WebUI login credentials
# Important: If using OIDC authentication, the API_USER must match the username
# provided by the OIDC provider.
GODOXY_API_USER=admin
# Password is not required for OIDC authentication
GODOXY_API_PASSWORD=password
# OIDC Configuration (optional) # OIDC Configuration (optional)
# Uncomment and configure these values to enable OIDC authentication. # Uncomment and configure these values to enable OIDC authentication.
# GODOXY_OIDC_ISSUER_URL=https://accounts.google.com # GODOXY_OIDC_ISSUER_URL=https://accounts.google.com
@ -24,6 +19,8 @@ GODOXY_API_PASSWORD=password
# GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback # GODOXY_OIDC_REDIRECT_URL=https://your-domain/api/auth/callback
# Comma-separated list of scopes # Comma-separated list of scopes
# GODOXY_OIDC_SCOPES=openid, profile, email # GODOXY_OIDC_SCOPES=openid, profile, email
# Comma-separated list of allowed users
# GODOXY_OIDC_ALLOWED_USERS=user1,user2
# Proxy listening address # Proxy listening address
GODOXY_HTTP_ADDR=:80 GODOXY_HTTP_ADDR=:80

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/yusing/go-proxy/internal" "github.com/yusing/go-proxy/internal"
"github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/query" "github.com/yusing/go-proxy/internal/api/v1/query"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
"github.com/yusing/go-proxy/internal/config" "github.com/yusing/go-proxy/internal/config"
@ -108,6 +109,10 @@ func main() {
return return
} }
if err := auth.Initialize(); err != nil {
logging.Fatal().Err(err).Msg("failed to initialize authentication")
}
cfg.Start() cfg.Start()
config.WatchChanges() config.WatchChanges()

View file

@ -1,43 +1,44 @@
package api package api
import ( import (
"net"
"net/http" "net/http"
v1 "github.com/yusing/go-proxy/internal/api/v1" v1 "github.com/yusing/go-proxy/internal/api/v1"
"github.com/yusing/go-proxy/internal/api/v1/auth" "github.com/yusing/go-proxy/internal/api/v1/auth"
"github.com/yusing/go-proxy/internal/api/v1/favicon" "github.com/yusing/go-proxy/internal/api/v1/favicon"
. "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
config "github.com/yusing/go-proxy/internal/config/types" config "github.com/yusing/go-proxy/internal/config/types"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ServeMux struct{ *http.ServeMux } type ServeMux struct{ *http.ServeMux }
func (mux ServeMux) HandleFunc(method, endpoint string, handler http.HandlerFunc) { func (mux ServeMux) HandleFunc(methods, endpoint string, handler http.HandlerFunc) {
mux.ServeMux.HandleFunc(method+" "+endpoint, checkHost(handler)) for _, m := range strutils.CommaSeperatedList(methods) {
mux.ServeMux.HandleFunc(m+" "+endpoint, handler)
}
} }
func NewHandler(cfg config.ConfigInstance) http.Handler { func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := ServeMux{http.NewServeMux()} mux := ServeMux{http.NewServeMux()}
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
mux.HandleFunc("GET", "/v1/auth/redirect", auth.APIAuthRedirectHandler)
mux.HandleFunc("GET", "/v1/auth/callback", auth.APIOIDCCallbackHandler)
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload)) mux.HandleFunc("POST", "/v1/reload", useCfg(cfg, v1.Reload))
mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list/{what}", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List))) mux.HandleFunc("GET", "/v1/list/{what}/{which}", auth.RequireAuth(useCfg(cfg, v1.List)))
mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent)) mux.HandleFunc("GET", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.GetFileContent))
mux.HandleFunc("POST", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent)) mux.HandleFunc("POST,PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("PUT", "/v1/file/{type}/{filename}", auth.RequireAuth(v1.SetFileContent))
mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile) mux.HandleFunc("GET", "/v1/schema/{filename...}", v1.GetSchemaFile)
mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats)) mux.HandleFunc("GET", "/v1/stats", useCfg(cfg, v1.Stats))
mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS)) mux.HandleFunc("GET", "/v1/stats/ws", useCfg(cfg, v1.StatsWS))
mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon)) mux.HandleFunc("GET", "/v1/favicon/{alias}", auth.RequireAuth(favicon.GetFavIcon))
defaultAuth := auth.GetDefaultAuth()
if defaultAuth != nil {
mux.HandleFunc("GET", "/v1/auth/redirect", defaultAuth.RedirectLoginPage)
mux.HandleFunc("GET,POST", "/v1/auth/callback", defaultAuth.LoginCallbackHandler)
mux.HandleFunc("GET,POST", "/v1/auth/logout", auth.LogoutCallbackHandler(defaultAuth))
}
return mux return mux
} }
@ -46,20 +47,3 @@ func useCfg(cfg config.ConfigInstance, handler func(cfg config.ConfigInstance, w
handler(cfg, w, r) handler(cfg, w, r)
} }
} }
// allow only requests to API server with localhost.
func checkHost(f http.HandlerFunc) http.HandlerFunc {
if common.IsDebug {
return f
}
return func(w http.ResponseWriter, r *http.Request) {
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if host != "127.0.0.1" && host != "localhost" && host != "[::1]" {
LogWarn(r).Msgf("blocked API request from %s", host)
http.Error(w, "forbidden", http.StatusForbidden)
return
}
LogDebug(r).Interface("headers", r.Header).Msg("API request")
f(w, r)
}
}

View file

@ -1,47 +1,35 @@
package auth package auth
import ( import (
"fmt"
"net"
"net/http" "net/http"
"time"
"github.com/golang-jwt/jwt/v5"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/logging" "github.com/yusing/go-proxy/internal/logging"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
type ( var defaultAuth Provider
Credentials struct {
Username string `json:"username"`
Password string `json:"password"`
}
Claims struct {
Username string `json:"username"`
jwt.RegisteredClaims
}
)
// init sets up authentication providers. // Initialize sets up authentication providers.
func init() { func Initialize() error {
if !IsEnabled() { if !IsEnabled() {
logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication") logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
return return nil
} }
var err error
// Initialize OIDC if configured. // Initialize OIDC if configured.
if common.OIDCIssuerURL != "" { if common.OIDCIssuerURL != "" {
if err := initOIDC( defaultAuth, err = NewOIDCProviderFromEnv()
common.OIDCIssuerURL, } else {
common.OIDCClientID, defaultAuth, err = NewUserPassAuthFromEnv()
common.OIDCClientSecret,
common.OIDCRedirectURL,
); err != nil {
logging.Fatal().Err(err).Msg("failed to initialize OIDC provider")
}
} }
return err
}
func GetDefaultAuth() Provider {
return defaultAuth
} }
func IsEnabled() bool { func IsEnabled() bool {
@ -52,85 +40,10 @@ func IsOIDCEnabled() bool {
return common.OIDCIssuerURL != "" return common.OIDCIssuerURL != ""
} }
// cookieFQDN returns the fully qualified domain name of the request host
// with subdomain stripped.
//
// If the request host does not have a subdomain,
// an empty string is returned
//
// "abc.example.com" -> "example.com"
// "example.com" -> ""
func cookieFQDN(r *http.Request) string {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
parts := strutils.SplitRune(host, '.')
if len(parts) < 2 {
return ""
}
parts[0] = ""
return strutils.JoinRune(parts, '.')
}
// APIAuthRedirectHandler handles API redirect to login page or OIDC login base on configuration.
func APIAuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
switch {
case apiOAuth != nil:
apiOAuth.RedirectOIDC(w, r)
return
case common.APIJWTSecret != nil:
http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
return
default:
http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
}
}
func setAuthenticatedCookie(w http.ResponseWriter, r *http.Request, username string) error {
expiresAt := time.Now().Add(common.APIJWTTokenTTL)
claim := &Claims{
Username: username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
tokenStr, err := token.SignedString(common.APIJWTSecret)
if err != nil {
return err
}
http.SetCookie(w, &http.Cookie{
Name: CookieToken,
Value: tokenStr,
Expires: expiresAt,
Domain: cookieFQDN(r),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
return nil
}
// LogoutHandler clear authentication cookie and redirect to login page.
func LogoutHandler(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: CookieToken,
MaxAge: -1,
Domain: cookieFQDN(r),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
APIAuthRedirectHandler(w, r)
}
func RequireAuth(next http.HandlerFunc) http.HandlerFunc { func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if IsEnabled() { if IsEnabled() {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if err := CheckToken(w, r); err != nil { if err := defaultAuth.CheckToken(w, r); err != nil {
U.RespondError(w, err, http.StatusUnauthorized) U.RespondError(w, err, http.StatusUnauthorized)
} else { } else {
next(w, r) next(w, r)
@ -139,30 +52,3 @@ func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
} }
return next return next
} }
func CheckToken(w http.ResponseWriter, r *http.Request) error {
tokenCookie, err := r.Cookie(CookieToken)
if err != nil {
return E.New("missing token")
}
var claims Claims
token, err := jwt.ParseWithClaims(tokenCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return common.APIJWTSecret, nil
})
if err != nil {
return err
}
switch {
case !token.Valid:
return E.New("invalid token")
case claims.Username != common.APIUser:
return E.New("username mismatch").Subject(claims.Username)
case claims.ExpiresAt.Before(time.Now()):
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
}
return nil
}

View file

@ -1,6 +0,0 @@
package auth
const (
CookieToken = "godoxy_token"
CookieOauthState = "godoxy_oauth_state"
)

View file

@ -2,8 +2,13 @@ package auth
import ( import (
"context" "context"
"crypto/rand"
"encoding/base64"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"slices"
"time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
@ -17,26 +22,17 @@ type OIDCProvider struct {
oauthConfig *oauth2.Config oauthConfig *oauth2.Config
oidcProvider *oidc.Provider oidcProvider *oidc.Provider
oidcVerifier *oidc.IDTokenVerifier oidcVerifier *oidc.IDTokenVerifier
allowedUsers []string
overrideHost bool overrideHost bool
} }
var ( const CookieOauthState = "godoxy_oidc_state"
apiOAuth *OIDCProvider
APIOIDCCallbackHandler http.HandlerFunc
)
// initOIDC initializes the OIDC provider. func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string, allowedUsers []string) (*OIDCProvider, error) {
func initOIDC(issuerURL, clientID, clientSecret, redirectURL string) (err error) { if len(allowedUsers) == 0 {
if issuerURL == "" { return nil, errors.New("OIDC allowed users must not be empty")
return nil // OIDC not configured
} }
apiOAuth, err = NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL)
APIOIDCCallbackHandler = apiOAuth.OIDCCallbackHandler
return
}
func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OIDCProvider, error) {
provider, err := oidc.NewProvider(context.Background(), issuerURL) provider, err := oidc.NewProvider(context.Background(), issuerURL)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err) return nil, fmt.Errorf("failed to initialize OIDC provider: %w", err)
@ -54,30 +50,82 @@ func NewOIDCProvider(issuerURL, clientID, clientSecret, redirectURL string) (*OI
oidcVerifier: provider.Verifier(&oidc.Config{ oidcVerifier: provider.Verifier(&oidc.Config{
ClientID: clientID, ClientID: clientID,
}), }),
allowedUsers: allowedUsers,
}, nil }, nil
} }
func NewOIDCProviderFromEnv(redirectURL string) (*OIDCProvider, error) { // NewOIDCProviderFromEnv creates a new OIDCProvider from environment variables.
func NewOIDCProviderFromEnv() (*OIDCProvider, error) {
return NewOIDCProvider( return NewOIDCProvider(
common.OIDCIssuerURL, common.OIDCIssuerURL,
common.OIDCClientID, common.OIDCClientID,
common.OIDCClientSecret, common.OIDCClientSecret,
redirectURL, common.OIDCRedirectURL,
common.OIDCAllowedUsers,
) )
} }
func (provider *OIDCProvider) SetOverrideHostEnabled(enabled bool) { func (auth *OIDCProvider) TokenCookieName() string {
provider.overrideHost = enabled return "godoxy_oidc_token"
}
func (auth *OIDCProvider) SetOverrideHostEnabled(enabled bool) {
auth.overrideHost = enabled
}
func (auth *OIDCProvider) SetAllowedUsers(users []string) {
auth.allowedUsers = users
}
func (auth *OIDCProvider) CheckToken(w http.ResponseWriter, r *http.Request) error {
token, err := r.Cookie(auth.TokenCookieName())
if err != nil {
return ErrMissingToken
}
// checks for Expiry, Audience == ClientID, Issuer, etc.
idToken, err := auth.oidcVerifier.Verify(r.Context(), token.Value)
if err != nil {
return fmt.Errorf("failed to verify ID token: %w", err)
}
if len(idToken.Audience) == 0 {
return ErrInvalidToken
}
var claims struct {
Email string `json:"email"`
Username string `json:"preferred_username"`
}
if err := idToken.Claims(&claims); err != nil {
return fmt.Errorf("failed to parse claims: %w", err)
}
if !slices.Contains(auth.allowedUsers, claims.Username) {
return ErrUserNotAllowed.Subject(claims.Username)
}
return nil
}
// generateState generates a random string for ODIC state.
const odicStateLength = 32
func generateState() (string, error) {
b := make([]byte, odicStateLength)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b)[:odicStateLength], nil
} }
// RedirectOIDC initiates the OIDC login flow. // RedirectOIDC initiates the OIDC login flow.
func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Request) { func (auth *OIDCProvider) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
if provider == nil { state, err := generateState()
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return return
} }
state := common.GenerateRandomString(32)
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: CookieOauthState, Name: CookieOauthState,
Value: state, Value: state,
@ -88,8 +136,8 @@ func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Reques
Path: "/", Path: "/",
}) })
redirURL := provider.oauthConfig.AuthCodeURL(state) redirURL := auth.oauthConfig.AuthCodeURL(state)
if provider.overrideHost { if auth.overrideHost {
u, err := r.URL.Parse(redirURL) u, err := r.URL.Parse(redirURL)
if err != nil { if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError) U.HandleErr(w, r, err, http.StatusInternalServerError)
@ -104,20 +152,10 @@ func (provider *OIDCProvider) RedirectOIDC(w http.ResponseWriter, r *http.Reques
} }
// OIDCCallbackHandler handles the OIDC callback. // OIDCCallbackHandler handles the OIDC callback.
func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) { func (auth *OIDCProvider) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
if provider == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return
}
// For testing purposes, skip provider verification // For testing purposes, skip provider verification
if common.IsTest { if common.IsTest {
handleTestCallback(w, r) auth.handleTestCallback(w, r)
return
}
if provider.oidcProvider == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return return
} }
@ -133,7 +171,7 @@ func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
oauth2Token, err := provider.oauthConfig.Exchange(r.Context(), code) oauth2Token, err := auth.oauthConfig.Exchange(r.Context(), code)
if err != nil { if err != nil {
U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError) U.HandleErr(w, r, fmt.Errorf("failed to exchange token: %w", err), http.StatusInternalServerError)
return return
@ -145,32 +183,20 @@ func (provider *OIDCProvider) OIDCCallbackHandler(w http.ResponseWriter, r *http
return return
} }
idToken, err := provider.oidcVerifier.Verify(r.Context(), rawIDToken) idToken, err := auth.oidcVerifier.Verify(r.Context(), rawIDToken)
if err != nil { if err != nil {
U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError) U.HandleErr(w, r, fmt.Errorf("failed to verify ID token: %w", err), http.StatusInternalServerError)
return return
} }
var claims struct { setTokenCookie(w, r, auth.TokenCookieName(), rawIDToken, time.Until(idToken.Expiry))
Email string `json:"email"`
Username string `json:"preferred_username"`
}
if err := idToken.Claims(&claims); err != nil {
U.HandleErr(w, r, fmt.Errorf("failed to parse claims: %w", err), http.StatusInternalServerError)
return
}
if err := setAuthenticatedCookie(w, r, claims.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
// Redirect to home page // Redirect to home page
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
} }
// handleTestCallback handles OIDC callback in test environment. // handleTestCallback handles OIDC callback in test environment.
func handleTestCallback(w http.ResponseWriter, r *http.Request) { func (auth *OIDCProvider) handleTestCallback(w http.ResponseWriter, r *http.Request) {
state, err := r.Cookie(CookieOauthState) state, err := r.Cookie(CookieOauthState)
if err != nil { if err != nil {
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
@ -183,10 +209,7 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
} }
// Create test JWT token // Create test JWT token
if err := setAuthenticatedCookie(w, r, "test-user"); err != nil { setTokenCookie(w, r, auth.TokenCookieName(), "test", time.Hour)
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
} }

View file

@ -1,6 +1,7 @@
package auth package auth
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -14,7 +15,8 @@ import (
func setupMockOIDC(t *testing.T) { func setupMockOIDC(t *testing.T) {
t.Helper() t.Helper()
apiOAuth = &OIDCProvider{ provider := (&oidc.ProviderConfig{}).NewProvider(context.TODO())
defaultAuth = &OIDCProvider{
oauthConfig: &oauth2.Config{ oauthConfig: &oauth2.Config{
ClientID: "test-client", ClientID: "test-client",
ClientSecret: "test-secret", ClientSecret: "test-secret",
@ -25,53 +27,42 @@ func setupMockOIDC(t *testing.T) {
}, },
Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}, },
oidcProvider: provider,
oidcVerifier: provider.Verifier(&oidc.Config{
ClientID: "test-client",
}),
allowedUsers: []string{"test-user"},
} }
} }
func cleanup() { func cleanup() {
apiOAuth = nil defaultAuth = nil
} }
func TestOIDCLoginHandler(t *testing.T) { func TestOIDCLoginHandler(t *testing.T) {
// Setup // Setup
common.APIJWTSecret = []byte("test-secret") common.APIJWTSecret = []byte("test-secret")
common.IsTest = true t.Cleanup(cleanup)
t.Cleanup(func() {
cleanup()
common.IsTest = false
})
setupMockOIDC(t) setupMockOIDC(t)
tests := []struct { tests := []struct {
name string name string
configureOAuth bool wantStatus int
wantStatus int wantRedirect bool
wantRedirect bool
}{ }{
{ {
name: "Success - Redirects to provider", name: "Success - Redirects to provider",
configureOAuth: true, wantStatus: http.StatusTemporaryRedirect,
wantStatus: http.StatusTemporaryRedirect, wantRedirect: true,
wantRedirect: true,
},
{
name: "Failure - OIDC not configured",
configureOAuth: false,
wantStatus: http.StatusNotImplemented,
wantRedirect: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if !tt.configureOAuth {
apiOAuth = nil
}
req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil) req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
apiOAuth.RedirectOIDC(w, req) defaultAuth.RedirectLoginPage(w, req)
if got := w.Code; got != tt.wantStatus { if got := w.Code; got != tt.wantStatus {
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)
@ -94,65 +85,45 @@ func TestOIDCLoginHandler(t *testing.T) {
func TestOIDCCallbackHandler(t *testing.T) { func TestOIDCCallbackHandler(t *testing.T) {
// Setup // Setup
common.APIJWTSecret = []byte("test-secret") common.APIJWTSecret = []byte("test-secret")
common.IsTest = true t.Cleanup(cleanup)
t.Cleanup(func() {
cleanup()
common.IsTest = false
})
tests := []struct { tests := []struct {
name string name string
configureOAuth bool state string
state string code string
code string setupMocks bool
setupMocks func() wantStatus int
wantStatus int
}{ }{
{ {
name: "Success - Valid callback", name: "Success - Valid callback",
configureOAuth: true, state: "valid-state",
state: "valid-state", code: "valid-code",
code: "valid-code", setupMocks: true,
setupMocks: func() {
setupMockOIDC(t)
},
wantStatus: http.StatusTemporaryRedirect, wantStatus: http.StatusTemporaryRedirect,
}, },
{ {
name: "Failure - OIDC not configured", name: "Failure - Missing state",
configureOAuth: false, code: "valid-code",
wantStatus: http.StatusNotImplemented, setupMocks: true,
},
{
name: "Failure - Missing state",
configureOAuth: true,
code: "valid-code",
setupMocks: func() {
setupMockOIDC(t)
},
wantStatus: http.StatusBadRequest, wantStatus: http.StatusBadRequest,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if tt.setupMocks != nil { if tt.setupMocks {
tt.setupMocks() setupMockOIDC(t)
}
if !tt.configureOAuth {
apiOAuth = nil
} }
req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil) req := httptest.NewRequest(http.MethodGet, "/auth/callback?code="+tt.code+"&state="+tt.state, nil)
if tt.state != "" { if tt.state != "" {
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "oauth_state", Name: CookieOauthState,
Value: tt.state, Value: tt.state,
}) })
} }
w := httptest.NewRecorder() w := httptest.NewRecorder()
apiOAuth.OIDCCallbackHandler(w, req) defaultAuth.LoginCallbackHandler(w, req)
if got := w.Code; got != tt.wantStatus { if got := w.Code; got != tt.wantStatus {
t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus) t.Errorf("OIDCCallbackHandler() status = %v, want %v", got, tt.wantStatus)
@ -169,32 +140,48 @@ func TestOIDCCallbackHandler(t *testing.T) {
} }
func TestInitOIDC(t *testing.T) { func TestInitOIDC(t *testing.T) {
common.IsTest = true
t.Cleanup(func() {
common.IsTest = false
})
tests := []struct { tests := []struct {
name string name string
issuerURL string issuerURL string
clientID string clientID string
clientSecret string clientSecret string
redirectURL string redirectURL string
allowedUsers []string
wantErr bool wantErr bool
}{ }{
{ {
name: "Success - Empty configuration", name: "Fail - Empty configuration",
issuerURL: "", issuerURL: "",
clientID: "", clientID: "",
clientSecret: "", clientSecret: "",
redirectURL: "", redirectURL: "",
wantErr: false, allowedUsers: nil,
wantErr: true,
},
// {
// name: "Success - Valid configuration",
// issuerURL: "https://example.com",
// clientID: "client_id",
// clientSecret: "client_secret",
// redirectURL: "https://example.com/callback",
// allowedUsers: []string{"user1", "user2"},
// wantErr: false,
// },
{
name: "Fail - No allowed users",
issuerURL: "https://example.com",
clientID: "client_id",
clientSecret: "client_secret",
redirectURL: "https://example.com/callback",
allowedUsers: []string{},
wantErr: true,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Cleanup(cleanup) t.Cleanup(cleanup)
err := initOIDC(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL) _, err := NewOIDCProvider(tt.issuerURL, tt.clientID, tt.clientSecret, tt.redirectURL, tt.allowedUsers)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("InitOIDC() error = %v, wantErr %v", err, tt.wantErr)
} }

View file

@ -0,0 +1,12 @@
package auth
import (
"net/http"
)
type Provider interface {
TokenCookieName() string
CheckToken(w http.ResponseWriter, r *http.Request) error
RedirectLoginPage(w http.ResponseWriter, r *http.Request)
LoginCallbackHandler(w http.ResponseWriter, r *http.Request)
}

View file

@ -1,13 +1,17 @@
package auth package auth
import ( import (
"bytes"
"encoding/json" "encoding/json"
"fmt"
"net/http" "net/http"
"time"
"github.com/golang-jwt/jwt/v5"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
"golang.org/x/crypto/bcrypt"
) )
var ( var (
@ -15,31 +19,120 @@ var (
ErrInvalidPassword = E.New("invalid password") ErrInvalidPassword = E.New("invalid password")
) )
func validatePassword(cred *Credentials) error { type (
if cred.Username != common.APIUser { UserPassAuth struct {
return ErrInvalidUsername.Subject(cred.Username) username string
pwdHash []byte
secret []byte
tokenTTL time.Duration
} }
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { UserPassClaims struct {
return ErrInvalidPassword.Subject(cred.Password) Username string `json:"username"`
jwt.RegisteredClaims
} }
)
func NewUserPassAuth(username, password string, secret []byte, tokenTTL time.Duration) (*UserPassAuth, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
return &UserPassAuth{
username: username,
pwdHash: hash,
secret: secret,
tokenTTL: tokenTTL,
}, nil
}
func NewUserPassAuthFromEnv() (*UserPassAuth, error) {
return NewUserPassAuth(
common.APIUser,
common.APIPassword,
common.APIJWTSecret,
common.APIJWTTokenTTL,
)
}
func (auth *UserPassAuth) TokenCookieName() string {
return "godoxy_token"
}
func (auth *UserPassAuth) CreateToken(w http.ResponseWriter, r *http.Request) (token string, err error) {
claim := &UserPassClaims{
Username: auth.username,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(time.Now().Add(auth.tokenTTL)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS512, claim)
token, err = tok.SignedString(auth.secret)
if err != nil {
return "", err
}
return token, nil
}
func (auth *UserPassAuth) CheckToken(w http.ResponseWriter, r *http.Request) error {
jwtCookie, err := r.Cookie(auth.TokenCookieName())
if err != nil {
return ErrMissingToken
}
var claims UserPassClaims
token, err := jwt.ParseWithClaims(jwtCookie.Value, &claims, func(t *jwt.Token) (interface{}, error) {
if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", t.Header["alg"])
}
return auth.secret, nil
})
if err != nil {
return err
}
switch {
case !token.Valid:
return ErrInvalidToken
case claims.Username != auth.username:
return ErrUserNotAllowed.Subject(claims.Username)
case claims.ExpiresAt.Before(time.Now()):
return E.Errorf("token expired on %s", strutils.FormatTime(claims.ExpiresAt.Time))
}
return nil return nil
} }
// UserPassLoginHandler handles user login. func (auth *UserPassAuth) RedirectLoginPage(w http.ResponseWriter, r *http.Request) {
func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
var creds Credentials }
func (auth *UserPassAuth) LoginCallbackHandler(w http.ResponseWriter, r *http.Request) {
var creds struct {
User string `json:"username"`
Pass string `json:"password"`
}
err := json.NewDecoder(r.Body).Decode(&creds) err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil { if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest) U.HandleErr(w, r, err, http.StatusBadRequest)
return return
} }
if err := validatePassword(&creds); err != nil { if err := auth.validatePassword(creds.User, creds.Pass); err != nil {
U.HandleErr(w, r, err, http.StatusUnauthorized) U.HandleErr(w, r, err, http.StatusUnauthorized)
return return
} }
if err := setAuthenticatedCookie(w, r, creds.Username); err != nil { token, err := auth.CreateToken(w, r)
if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError) U.HandleErr(w, r, err, http.StatusInternalServerError)
return return
} }
setTokenCookie(w, r, auth.TokenCookieName(), token, auth.tokenTTL)
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
func (auth *UserPassAuth) validatePassword(user, pass string) error {
if user != auth.username {
return ErrInvalidUsername.Subject(user)
}
if err := bcrypt.CompareHashAndPassword(auth.pwdHash, []byte(pass)); err != nil {
return ErrInvalidPassword.Subject(pass)
}
return nil
}

View file

@ -0,0 +1,70 @@
package auth
import (
"net"
"net/http"
"time"
E "github.com/yusing/go-proxy/internal/error"
"github.com/yusing/go-proxy/internal/utils/strutils"
)
var (
ErrMissingToken = E.New("missing token")
ErrInvalidToken = E.New("invalid token")
ErrUserNotAllowed = E.New("user not allowed")
)
// cookieFQDN returns the fully qualified domain name of the request host
// with subdomain stripped.
//
// If the request host does not have a subdomain,
// an empty string is returned
//
// "abc.example.com" -> "example.com"
// "example.com" -> ""
func cookieFQDN(r *http.Request) string {
host, _, err := net.SplitHostPort(r.Host)
if err != nil {
host = r.Host
}
parts := strutils.SplitRune(host, '.')
if len(parts) < 2 {
return ""
}
parts[0] = ""
return strutils.JoinRune(parts, '.')
}
func setTokenCookie(w http.ResponseWriter, r *http.Request, name, value string, ttl time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: value,
MaxAge: int(ttl.Seconds()),
Domain: cookieFQDN(r),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
}
func clearTokenCookie(w http.ResponseWriter, r *http.Request, name string) {
http.SetCookie(w, &http.Cookie{
Name: name,
Value: "",
MaxAge: -1,
Domain: cookieFQDN(r),
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
Path: "/",
})
}
func LogoutCallbackHandler(auth Provider) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
clearTokenCookie(w, r, auth.TokenCookieName())
auth.RedirectLoginPage(w, r)
}
}

View file

@ -11,6 +11,7 @@ func reqLogger(r *http.Request, level zerolog.Level) *zerolog.Event {
return logging.WithLevel(level). return logging.WithLevel(level).
Str("module", "api"). Str("module", "api").
Str("remote", r.RemoteAddr). Str("remote", r.RemoteAddr).
Str("host", r.Host).
Str("uri", r.Method+" "+r.RequestURI) Str("uri", r.Method+" "+r.RequestURI)
} }

View file

@ -1,18 +1,11 @@
package common package common
import ( import (
"crypto/sha512"
"encoding/base64" "encoding/base64"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func HashPassword(pwd string) []byte {
h := sha512.New()
h.Write([]byte(pwd))
return h.Sum(nil)
}
func decodeJWTKey(key string) []byte { func decodeJWTKey(key string) []byte {
if key == "" { if key == "" {
return nil return nil

View file

@ -9,6 +9,7 @@ import (
"time" "time"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/yusing/go-proxy/internal/utils/strutils"
) )
var ( var (
@ -40,10 +41,10 @@ var (
MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http") MetricsHTTPURL = GetAddrEnv("PROMETHEUS_ADDR", "", "http")
PrometheusEnabled = MetricsHTTPURL != "" PrometheusEnabled = MetricsHTTPURL != ""
APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", "")) APIJWTSecret = decodeJWTKey(GetEnvString("API_JWT_SECRET", ""))
APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour) APIJWTTokenTTL = GetDurationEnv("API_JWT_TOKEN_TTL", time.Hour)
APIUser = GetEnvString("API_USER", "admin") APIUser = GetEnvString("API_USER", "admin")
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password")) APIPassword = GetEnvString("API_PASSWORD", "password")
// OIDC Configuration. // OIDC Configuration.
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
@ -51,6 +52,7 @@ var (
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")
OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "") OIDCRedirectURL = GetEnvString("OIDC_REDIRECT_URL", "")
OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email") OIDCScopes = GetEnvString("OIDC_SCOPES", "openid, profile, email")
OIDCAllowedUsers = GetCommaSepEnv("OIDC_ALLOWED_USERS", "")
) )
func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T { func GetEnv[T any](key string, defaultValue T, parser func(string) (T, error)) T {
@ -102,3 +104,7 @@ func GetAddrEnv(key, defaultValue, scheme string) (addr, host, port, fullURL str
func GetDurationEnv(key string, defaultValue time.Duration) time.Duration { func GetDurationEnv(key string, defaultValue time.Duration) time.Duration {
return GetEnv(key, defaultValue, time.ParseDuration) return GetEnv(key, defaultValue, time.ParseDuration)
} }
func GetCommaSepEnv(key string, defaultValue string) []string {
return strutils.CommaSeperatedList(GetEnvString(key, defaultValue))
}

View file

@ -1,13 +0,0 @@
package common
import (
"crypto/rand"
"encoding/base64"
)
// GenerateRandomString generates a random string of specified length.
func GenerateRandomString(length int) string {
b := make([]byte, length)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)[:length]
}

View file

@ -8,43 +8,47 @@ import (
) )
type oidcMiddleware struct { type oidcMiddleware struct {
oauth *auth.OIDCProvider AllowedUsers []string
authMux *http.ServeMux
auth auth.Provider
authMux *http.ServeMux
logoutHandler http.HandlerFunc
} }
var OIDC = NewMiddleware[oidcMiddleware]() var OIDC = NewMiddleware[oidcMiddleware]()
const ( const (
OIDCMiddlewareCallbackPath = "/godoxy-auth-oidc/callback" OIDCMiddlewareCallbackPath = "/auth/callback"
OIDCLogoutPath = "/logout" OIDCLogoutPath = "/auth/logout"
) )
func (amw *oidcMiddleware) finalize() error { func (amw *oidcMiddleware) finalize() error {
if !auth.IsOIDCEnabled() { if !auth.IsOIDCEnabled() {
return E.New("OIDC not enabled but Auth middleware is used") return E.New("OIDC not enabled but Auth middleware is used")
} }
provider, err := auth.NewOIDCProviderFromEnv(OIDCMiddlewareCallbackPath) authProvider, err := auth.NewOIDCProviderFromEnv()
if err != nil { if err != nil {
return err return err
} }
provider.SetOverrideHostEnabled(true) authProvider.SetOverrideHostEnabled(true)
amw.oauth = provider
amw.authMux = http.NewServeMux() amw.authMux = http.NewServeMux()
amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, provider.OIDCCallbackHandler) amw.authMux.HandleFunc(OIDCMiddlewareCallbackPath, authProvider.LoginCallbackHandler)
amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) { amw.authMux.HandleFunc(OIDCLogoutPath, func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Unauthorized", http.StatusUnauthorized) http.Error(w, "Unauthorized", http.StatusUnauthorized)
}) })
amw.authMux.HandleFunc("/", provider.RedirectOIDC) amw.authMux.HandleFunc("/", authProvider.RedirectLoginPage)
amw.logoutHandler = auth.LogoutCallbackHandler(authProvider)
amw.auth = authProvider
return nil return nil
} }
func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) { func (amw *oidcMiddleware) before(w http.ResponseWriter, r *http.Request) (proceed bool) {
if err, _ := auth.CheckToken(w, r); err != nil { if err := amw.auth.CheckToken(w, r); err != nil {
amw.authMux.ServeHTTP(w, r) amw.authMux.ServeHTTP(w, r)
return false return false
} }
if r.URL.Path == OIDCLogoutPath { if r.URL.Path == OIDCLogoutPath {
auth.LogoutHandler(w, r) amw.logoutHandler(w, r)
return false return false
} }
return true return true

View file

@ -10,6 +10,9 @@ import (
// CommaSeperatedList returns a list of strings split by commas, // CommaSeperatedList returns a list of strings split by commas,
// then trim spaces from each element. // then trim spaces from each element.
func CommaSeperatedList(s string) []string { func CommaSeperatedList(s string) []string {
if s == "" {
return []string{}
}
res := SplitComma(s) res := SplitComma(s)
for i, part := range res { for i, part := range res {
res[i] = strings.TrimSpace(part) res[i] = strings.TrimSpace(part)

View file

@ -75,13 +75,12 @@ GoDoxy v0.8.2 expected changes
- **Thanks [polds](https://github.com/polds)** - **Thanks [polds](https://github.com/polds)**
Support WebUI authentication via OIDC by setting these environment variables: Support WebUI authentication via OIDC by setting these environment variables:
- `GODOXY_API_USER`
- `GODOXY_API_JWT_SECRET`
- `GODOXY_OIDC_ISSUER_URL` - `GODOXY_OIDC_ISSUER_URL`
- `GODOXY_OIDC_CLIENT_ID` - `GODOXY_OIDC_CLIENT_ID`
- `GODOXY_OIDC_CLIENT_SECRET` - `GODOXY_OIDC_CLIENT_SECRET`
- `GODOXY_OIDC_REDIRECT_URL` - `GODOXY_OIDC_REDIRECT_URL`
- `GODOXY_OIDC_SCOPES` - `GODOXY_OIDC_SCOPES` _(optional)_
- `GODOXY_OIDC_ALLOWED_USERS`
- Caddyfile like rules - Caddyfile like rules