cleanup code, redirect to auth page when need

This commit is contained in:
yusing 2025-01-13 07:15:29 +08:00
parent ef277ef57f
commit 76fe5345d8
11 changed files with 113 additions and 109 deletions

View file

@ -109,16 +109,16 @@ func main() {
return return
} }
if common.APIJWTSecret == nil {
logging.Warn().Msg("API JWT secret is empty, authentication is disabled")
}
cfg.Start() cfg.Start()
config.WatchChanges() config.WatchChanges()
// Initialize authentication providers if !auth.IsEnabled() {
if err := auth.Initialize(); err != nil { logging.Warn().Msg("authentication is disabled, please set API_JWT_SECRET or OIDC_* to enable authentication")
logging.Warn().Err(err).Msg("Failed to initialize authentication providers") } else {
// Initialize authentication providers
if err := auth.Initialize(); err != nil {
logging.Fatal().Err(err).Msg("Failed to initialize authentication providers")
}
} }
sig := make(chan os.Signal, 1) sig := make(chan os.Signal, 1)

View file

@ -22,9 +22,8 @@ func NewHandler(cfg config.ConfigInstance) http.Handler {
mux := ServeMux{http.NewServeMux()} mux := ServeMux{http.NewServeMux()}
mux.HandleFunc("GET", "/v1", v1.Index) mux.HandleFunc("GET", "/v1", v1.Index)
mux.HandleFunc("GET", "/v1/version", v1.GetVersion) mux.HandleFunc("GET", "/v1/version", v1.GetVersion)
mux.HandleFunc("POST", "/v1/login", auth.LoginHandler) mux.HandleFunc("POST", "/v1/login", auth.UserPassLoginHandler)
mux.HandleFunc("GET", "/v1/login/method", auth.AuthMethodHandler) mux.HandleFunc("GET", "/v1/auth/redirect", auth.AuthRedirectHandler)
mux.HandleFunc("GET", "/v1/login/oidc", auth.OIDCLoginHandler)
mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler) mux.HandleFunc("GET", "/v1/auth/callback", auth.OIDCCallbackHandler)
mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("GET", "/v1/logout", auth.LogoutHandler)
mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler) mux.HandleFunc("POST", "/v1/logout", auth.LogoutHandler)

View file

@ -1,8 +1,6 @@
package auth package auth
import ( import (
"bytes"
"encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"time" "time"
@ -25,51 +23,37 @@ type (
} }
) )
var ( // Initialize sets up authentication providers.
ErrInvalidUsername = E.New("invalid username") func Initialize() error {
ErrInvalidPassword = E.New("invalid password") // Initialize OIDC if configured.
) if common.OIDCIssuerURL != "" {
return InitOIDC(
func validatePassword(cred *Credentials) error { common.OIDCIssuerURL,
if cred.Username != common.APIUser { common.OIDCClientID,
return ErrInvalidUsername.Subject(cred.Username) common.OIDCClientSecret,
} common.OIDCRedirectURL,
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) { )
return ErrInvalidPassword.Subject(cred.Password)
} }
return nil return nil
} }
func LoginHandler(w http.ResponseWriter, r *http.Request) { func IsEnabled() bool {
var creds Credentials return common.APIJWTSecret != nil || common.OIDCIssuerURL != ""
err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
if err := validatePassword(&creds); err != nil {
U.HandleErr(w, r, err, http.StatusUnauthorized)
return
}
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
} }
func AuthMethodHandler(w http.ResponseWriter, r *http.Request) { // AuthRedirectHandler handles redirect to login page or OIDC login base on configuration.
func AuthRedirectHandler(w http.ResponseWriter, r *http.Request) {
switch { switch {
case oauthConfig != nil:
RedirectOIDC(w, r)
return
case common.APIJWTSecret == nil: case common.APIJWTSecret == nil:
U.WriteBody(w, []byte("skip")) http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
case common.OIDCIssuerURL != "": return
U.WriteBody(w, []byte("oidc"))
case common.APIPasswordHash != nil:
U.WriteBody(w, []byte("password"))
default: default:
U.WriteBody(w, []byte("skip")) U.WriteBody(w, []byte("skip"))
w.WriteHeader(http.StatusOK)
} }
w.WriteHeader(http.StatusOK)
} }
func setAuthenticatedCookie(w http.ResponseWriter, username string) error { func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
@ -86,57 +70,44 @@ func setAuthenticatedCookie(w http.ResponseWriter, username string) error {
return err return err
} }
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: "token", Name: CookieToken,
Value: tokenStr, Value: tokenStr,
Expires: expiresAt, Expires: expiresAt,
HttpOnly: true, HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
Path: "/", Path: "/",
}) })
return nil return nil
} }
// LogoutHandler clear authentication cookie and redirect to login page.
func LogoutHandler(w http.ResponseWriter, r *http.Request) { func LogoutHandler(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: "token", Name: CookieToken,
Value: "", Value: "",
Expires: time.Unix(0, 0), Expires: time.Unix(0, 0),
HttpOnly: true, HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode, SameSite: http.SameSiteStrictMode,
Path: "/", Path: "/",
}) })
w.Header().Set("location", "/login") AuthRedirectHandler(w, r)
w.WriteHeader(http.StatusTemporaryRedirect)
}
// Initialize sets up authentication providers.
func Initialize() error {
// Initialize OIDC if configured.
if common.OIDCIssuerURL != "" {
return InitOIDC(
common.OIDCIssuerURL,
common.OIDCClientID,
common.OIDCClientSecret,
common.OIDCRedirectURL,
)
}
return nil
} }
func RequireAuth(next http.HandlerFunc) http.HandlerFunc { func RequireAuth(next http.HandlerFunc) http.HandlerFunc {
if common.IsDebugSkipAuth || common.APIJWTSecret == nil { if IsEnabled() {
return next return func(w http.ResponseWriter, r *http.Request) {
} if checkToken(w, r) {
next(w, r)
return func(w http.ResponseWriter, r *http.Request) { }
if checkToken(w, r) {
next(w, r)
} }
} }
return next
} }
func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) { func checkToken(w http.ResponseWriter, r *http.Request) (ok bool) {
tokenCookie, err := r.Cookie("token") tokenCookie, err := r.Cookie(CookieToken)
if err != nil { if err != nil {
U.RespondError(w, E.New("missing token"), http.StatusUnauthorized) U.RespondError(w, E.New("missing token"), http.StatusUnauthorized)
return false return false

View file

@ -0,0 +1,6 @@
package auth
const (
CookieToken = "token"
CookieOauthState = "oauth_state"
)

View file

@ -4,10 +4,8 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"time"
"github.com/coreos/go-oidc/v3/oidc" "github.com/coreos/go-oidc/v3/oidc"
"github.com/golang-jwt/jwt/v5"
U "github.com/yusing/go-proxy/internal/api/v1/utils" U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common" "github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error" E "github.com/yusing/go-proxy/internal/error"
@ -47,8 +45,8 @@ func InitOIDC(issuerURL, clientID, clientSecret, redirectURL string) error {
return nil return nil
} }
// OIDCLoginHandler initiates the OIDC login flow. // RedirectOIDC initiates the OIDC login flow.
func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) { func RedirectOIDC(w http.ResponseWriter, r *http.Request) {
if oauthConfig == nil { if oauthConfig == nil {
U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented) U.HandleErr(w, r, E.New("OIDC not configured"), http.StatusNotImplemented)
return return
@ -56,7 +54,7 @@ func OIDCLoginHandler(w http.ResponseWriter, r *http.Request) {
state := common.GenerateRandomString(32) state := common.GenerateRandomString(32)
http.SetCookie(w, &http.Cookie{ http.SetCookie(w, &http.Cookie{
Name: "oauth_state", Name: CookieOauthState,
Value: state, Value: state,
MaxAge: 300, MaxAge: 300,
HttpOnly: true, HttpOnly: true,
@ -87,7 +85,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
return return
} }
state, err := r.Cookie("oauth_state") state, err := r.Cookie(CookieOauthState)
if err != nil { if err != nil {
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
return return
@ -137,7 +135,7 @@ func OIDCCallbackHandler(w http.ResponseWriter, r *http.Request) {
// handleTestCallback handles OIDC callback in test environment. // handleTestCallback handles OIDC callback in test environment.
func handleTestCallback(w http.ResponseWriter, r *http.Request) { func handleTestCallback(w http.ResponseWriter, r *http.Request) {
state, err := r.Cookie("oauth_state") state, err := r.Cookie(CookieOauthState)
if err != nil { if err != nil {
U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest) U.HandleErr(w, r, E.New("missing state cookie"), http.StatusBadRequest)
return return
@ -149,29 +147,10 @@ func handleTestCallback(w http.ResponseWriter, r *http.Request) {
} }
// Create test JWT token // Create test JWT token
expiresAt := time.Now().Add(common.APIJWTTokenTTL) if err := setAuthenticatedCookie(w, "test-user"); err != nil {
jwtClaims := &Claims{
Username: "test-user",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(expiresAt),
},
}
token := jwt.NewWithClaims(jwt.SigningMethodHS512, jwtClaims)
tokenStr, err := token.SignedString(common.APIJWTSecret)
if err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError) U.HandleErr(w, r, err, http.StatusInternalServerError)
return return
} }
http.SetCookie(w, &http.Cookie{
Name: "token",
Value: tokenStr,
Expires: expiresAt,
HttpOnly: true,
SameSite: http.SameSiteStrictMode,
Path: "/",
})
http.Redirect(w, r, "/", http.StatusTemporaryRedirect) http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
} }

View file

@ -68,10 +68,10 @@ func TestOIDCLoginHandler(t *testing.T) {
oauthConfig = nil oauthConfig = nil
} }
req := httptest.NewRequest(http.MethodGet, "/login/oidc", nil) req := httptest.NewRequest(http.MethodGet, "/auth/redirect", nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
OIDCLoginHandler(w, req) RedirectOIDC(w, req)
if got := w.Code; got != tt.wantStatus { if got := w.Code; got != tt.wantStatus {
t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus) t.Errorf("OIDCLoginHandler() status = %v, want %v", got, tt.wantStatus)

View file

@ -0,0 +1,45 @@
package auth
import (
"bytes"
"encoding/json"
"net/http"
U "github.com/yusing/go-proxy/internal/api/v1/utils"
"github.com/yusing/go-proxy/internal/common"
E "github.com/yusing/go-proxy/internal/error"
)
var (
ErrInvalidUsername = E.New("invalid username")
ErrInvalidPassword = E.New("invalid password")
)
func validatePassword(cred *Credentials) error {
if cred.Username != common.APIUser {
return ErrInvalidUsername.Subject(cred.Username)
}
if !bytes.Equal(common.HashPassword(cred.Password), common.APIPasswordHash) {
return ErrInvalidPassword.Subject(cred.Password)
}
return nil
}
// UserPassLoginHandler handles user login.
func UserPassLoginHandler(w http.ResponseWriter, r *http.Request) {
var creds Credentials
err := json.NewDecoder(r.Body).Decode(&creds)
if err != nil {
U.HandleErr(w, r, err, http.StatusBadRequest)
return
}
if err := validatePassword(&creds); err != nil {
U.HandleErr(w, r, err, http.StatusUnauthorized)
return
}
if err := setAuthenticatedCookie(w, creds.Username); err != nil {
U.HandleErr(w, r, err, http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}

View file

@ -7,7 +7,7 @@ import (
"github.com/yusing/go-proxy/internal/utils/strutils/ansi" "github.com/yusing/go-proxy/internal/utils/strutils/ansi"
) )
// HandleErr logs the error and returns an HTTP error response to the client. // HandleErr logs the error and returns an error code to the client.
// If code is specified, it will be used as the HTTP status code; otherwise, // If code is specified, it will be used as the HTTP status code; otherwise,
// http.StatusInternalServerError is used. // http.StatusInternalServerError is used.
// //
@ -23,10 +23,14 @@ func HandleErr(w http.ResponseWriter, r *http.Request, err error, code ...int) {
http.Error(w, http.StatusText(code[0]), code[0]) http.Error(w, http.StatusText(code[0]), code[0])
} }
// RespondError returns error details to the client.
// If code is specified, it will be used as the HTTP status code; otherwise,
// http.StatusBadRequest is used.
func RespondError(w http.ResponseWriter, err error, code ...int) { func RespondError(w http.ResponseWriter, err error, code ...int) {
if len(code) == 0 { if len(code) == 0 {
code = []int{http.StatusBadRequest} code = []int{http.StatusBadRequest}
} }
// strip ANSI color codes added from Error.WithSubject
http.Error(w, ansi.StripANSI(err.Error()), code[0]) http.Error(w, ansi.StripANSI(err.Error()), code[0])
} }

View file

@ -11,7 +11,7 @@ import (
func WriteBody(w http.ResponseWriter, body []byte) { func WriteBody(w http.ResponseWriter, body []byte) {
if _, err := w.Write(body); err != nil { if _, err := w.Write(body); err != nil {
HandleErr(w, nil, err) logging.Err(err).Msg("failed to write body")
} }
} }

View file

@ -14,11 +14,10 @@ import (
var ( var (
prefixes = []string{"GODOXY_", "GOPROXY_", ""} prefixes = []string{"GODOXY_", "GOPROXY_", ""}
IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test") IsTest = GetEnvBool("TEST", false) || strings.HasSuffix(os.Args[0], ".test")
IsDebug = GetEnvBool("DEBUG", IsTest) IsDebug = GetEnvBool("DEBUG", IsTest)
IsDebugSkipAuth = GetEnvBool("DEBUG_SKIP_AUTH", false) IsTrace = GetEnvBool("TRACE", false) && IsDebug
IsTrace = GetEnvBool("TRACE", false) && IsDebug IsProduction = !IsTest && !IsDebug
IsProduction = !IsTest && !IsDebug
ProxyHTTPAddr, ProxyHTTPAddr,
ProxyHTTPHost, ProxyHTTPHost,
@ -46,7 +45,7 @@ var (
APIUser = GetEnvString("API_USER", "admin") APIUser = GetEnvString("API_USER", "admin")
APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password")) APIPasswordHash = HashPassword(GetEnvString("API_PASSWORD", "password"))
// OIDC Configuration // OIDC Configuration.
OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "") OIDCIssuerURL = GetEnvString("OIDC_ISSUER_URL", "")
OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "") OIDCClientID = GetEnvString("OIDC_CLIENT_ID", "")
OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "") OIDCClientSecret = GetEnvString("OIDC_CLIENT_SECRET", "")

View file

@ -31,6 +31,7 @@ func makeEntries(cont *types.Container, dockerHostIP ...string) route.RawEntries
} else { } else {
host = client.DefaultDockerHost host = client.DefaultDockerHost
} }
p.name = "test"
entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(cont, host))) entries := E.Must(p.entriesFromContainerLabels(D.FromDocker(cont, host)))
entries.RangeAll(func(k string, v *route.RawEntry) { entries.RangeAll(func(k string, v *route.RawEntry) {
v.Finalize() v.Finalize()